
Commit 75456d7

add second input

1 parent 9df1cdc

12 files changed (+250, -76 lines)

_doc/api/tasks/index.rst

Lines changed: 5 additions & 2 deletions
@@ -9,8 +9,11 @@ All submodules contain the following three functions:
 * ``random_input_kwargs(config) -> kwargs, get_inputs``:
   produces values ``get_inputs`` can take to generate dummy inputs
   suitable for a model defined by its configuration
-* ``get_inputs(model, config, *args, **kwargs) -> dict(inputs=..., dynamic_shapes=...)``:
-  generates the dummy inputs and dynamic shapes for a specific model and configuration.
+* ``get_inputs(model, config, *args, add_second_input=False, **kwargs) -> dict(inputs=..., dynamic_shapes=...)``:
+  generates the dummy inputs and dynamic shapes for a specific model and configuration.
+  If ``add_second_input`` is True, the function returns a second set of inputs
+  with different values for the dynamic dimensions. It is usually better to
+  rely on this function, as the dynamic dimensions may be correlated.

 For a specific task, you would write:
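A minimal sketch of the intended call pattern, based only on the signatures documented above (``model`` and ``config`` stand for a transformers model and its configuration):

# Sketch: build two input sets whose dynamic dimensions take different values.
kwargs, get_inputs = random_input_kwargs(config)
res = get_inputs(model, config, add_second_input=True, **kwargs)
inputs, ds = res["inputs"], res["dynamic_shapes"]
inputs2 = res["inputs2"]  # same structure, different dynamic-dimension values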

_unittests/ut_tasks/test_tasks.py

Lines changed: 68 additions & 10 deletions
@@ -10,11 +10,41 @@ class TestTasks(ExtTestCase):
     @hide_stdout()
     def test_text2text_generation(self):
         mid = "sshleifer/tiny-marian-en-de"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "text2text-generation")
         self.assertIn((data["size"], data["n_weights"]), [(473928, 118482)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         raise unittest.SkipTest(f"not working for {mid!r}")
         model(**inputs)
+        model(**data["inputs2"])
+        with bypass_export_some_errors(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
+    @hide_stdout()
+    def test_text_generation(self):
+        mid = "arnir0/Tiny-LLM"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "text-generation")
+        self.assertIn((data["size"], data["n_weights"]), [(51955968, 12988992)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**inputs)
+        model(**data["inputs2"])
+        with bypass_export_some_errors(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
+    @hide_stdout()
+    def test_image_classification(self):
+        mid = "hf-internal-testing/tiny-random-BeitForImageClassification"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "image-classification")
+        self.assertIn((data["size"], data["n_weights"]), [(56880, 14220)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**inputs)
+        model(**data["inputs2"])
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -23,9 +53,11 @@ def test_text2text_generation(self):
     @hide_stdout()
     def test_automatic_speech_recognition(self):
         mid = "openai/whisper-tiny"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "automatic-speech-recognition")
         self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**data["inputs2"])
         Dim = torch.export.Dim
         self.maxDiff = None
         self.assertIn("{0:Dim(batch),1:DYN(seq_length)}", self.string_type(ds))
@@ -91,13 +123,15 @@ def test_automatic_speech_recognition(self):
         )

     @hide_stdout()
-    def test_imagetext2text_generation(self):
+    def test_image_text_to_text(self):
         mid = "HuggingFaceM4/tiny-random-idefics"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "image-text-to-text")
         self.assertIn((data["size"], data["n_weights"]), [(12742888, 3185722)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
-        if not has_torch("2.10"):
+        model(**data["inputs2"])
+        if not has_torch("2.8"):
             raise unittest.SkipTest("sym_max does not work with dynamic dimension")
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
@@ -107,10 +141,12 @@ def test_imagetext2text_generation(self):
     @hide_stdout()
     def test_fill_mask(self):
         mid = "google-bert/bert-base-multilingual-cased"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "fill-mask")
         self.assertIn((data["size"], data["n_weights"]), [(428383212, 107095803)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
+        model(**data["inputs2"])
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -119,10 +155,12 @@ def test_fill_mask(self):
     @hide_stdout()
     def test_feature_extraction(self):
         mid = "facebook/bart-base"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "feature-extraction")
         self.assertIn((data["size"], data["n_weights"]), [(557681664, 139420416)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
+        model(**data["inputs2"])
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -131,10 +169,12 @@ def test_feature_extraction(self):
     @hide_stdout()
     def test_text_classification(self):
         mid = "Intel/bert-base-uncased-mrpc"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "text-classification")
         self.assertIn((data["size"], data["n_weights"]), [(154420232, 38605058)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
+        model(**data["inputs2"])
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -143,10 +183,12 @@ def test_text_classification(self):
     @hide_stdout()
     def test_sentence_similary(self):
         mid = "sentence-transformers/all-MiniLM-L6-v1"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "sentence-similarity")
         self.assertIn((data["size"], data["n_weights"]), [(62461440, 15615360)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
+        model(**data["inputs2"])
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -155,9 +197,11 @@ def test_sentence_similary(self):
     @hide_stdout()
     def test_falcon_mamba_dev(self):
         mid = "tiiuae/falcon-mamba-tiny-dev"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "text-generation")
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
+        model(**data["inputs2"])
         self.assertIn((data["size"], data["n_weights"]), [(138640384, 34660096)])
         if not has_transformers("4.55"):
             raise unittest.SkipTest("The model has control flow.")
@@ -166,6 +210,20 @@ def test_falcon_mamba_dev(self):
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )

+    @hide_stdout()
+    def test_zero_shot_image_classification(self):
+        mid = "openai/clip-vit-base-patch16"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "zero-shot-image-classification")
+        self.assertIn((data["size"], data["n_weights"]), [(188872708, 47218177)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**inputs)
+        model(**data["inputs2"])
+        with bypass_export_some_errors(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
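Every test follows the same recipe: run the model on both input sets, then export with dynamic shapes. The second set exists so the exported program can be replayed on different dynamic-dimension values. A minimal sketch of that end-to-end check (the import path and the plain ``ds`` are assumptions; the tests above additionally wrap ``ds`` with ``use_dyn_not_str``):

# Sketch: export with the first input set, then run the exported program
# on the second set; this fails if a dynamic dimension was specialized.
import torch
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs  # assumed path

data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
ep.module()(**data["inputs2"])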

onnx_diagnostic/tasks/automatic_speech_recognition.py

Lines changed: 18 additions & 2 deletions
@@ -69,7 +69,6 @@ def get_inputs(
         use_cache:bool,return_dict:bool
     )
     """
-    assert not add_second_input, "add_second_input=True not yet implemented"
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"

@@ -128,7 +127,24 @@ def get_inputs(
         # encoder_last_hidden_state=torch.randn(batch_size, sequence_length2, encoder_dim),
         # encoder_outputs=torch.randn(batch_size, sequence_length2, encoder_dim),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            max_source_positions=max_source_positions,
+            d_model=d_model,
+            num_hidden_layers=num_hidden_layers,
+            encoder_attention_heads=encoder_attention_heads,
+            encoder_layers=encoder_layers,
+            decoder_layers=decoder_layers,
+            head_dim=head_dim,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            **kwargs,
+        )["inputs"]
+    return res


 def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
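Each task module implements the flag the same way: build the first set, then re-invoke ``get_inputs`` with every dynamic dimension shifted by one and keep only the resulting ``inputs``, so correlated dimensions stay consistent with each other. A self-contained illustration of the pattern (a hypothetical minimal task module, not code from this commit):

import torch

def get_inputs(batch_size: int = 2, sequence_length: int = 30,
               add_second_input: bool = False, **kwargs):
    # First set of dummy inputs plus the matching dynamic-shape spec.
    shapes = {"input_ids": {0: "batch", 1: "seq_length"}}
    inputs = dict(input_ids=torch.randint(0, 100, (batch_size, sequence_length)))
    res = dict(inputs=inputs, dynamic_shapes=shapes)
    if add_second_input:
        # Recurse with every dynamic dimension shifted by one so the
        # second set takes a different value on each dynamic axis.
        res["inputs2"] = get_inputs(
            batch_size=batch_size + 1,
            sequence_length=sequence_length + 1,
            **kwargs,
        )["inputs"]
    return res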

onnx_diagnostic/tasks/feature_extraction.py

Lines changed: 11 additions & 2 deletions
@@ -35,7 +35,6 @@ def get_inputs(
         token_type_ids:T7s1x13[0,0:A0.0],
         attention_mask:T7s1x13[1,1:A1.0])
     """
-    assert not add_second_input, "add_second_input=True not yet implemented"
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "sequence_length"
     shapes = {
@@ -48,7 +47,17 @@ def get_inputs(
         ),
         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            dummy_max_token_id=dummy_max_token_id,
+            **kwargs,
+        )["inputs"]
+    return res


 def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

onnx_diagnostic/tasks/fill_mask.py

Lines changed: 11 additions & 2 deletions
@@ -35,7 +35,6 @@ def get_inputs(
         token_type_ids:T7s1x13[0,0:A0.0],
         attention_mask:T7s1x13[1,1:A1.0])
     """
-    assert not add_second_input, "add_second_input=True not yet implemented"
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "sequence_length"
     shapes = {
@@ -50,7 +49,17 @@ def get_inputs(
         token_type_ids=torch.zeros((batch_size, sequence_length)).to(torch.int64),
         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            dummy_max_token_id=dummy_max_token_id,
+            **kwargs,
+        )["inputs"]
+    return res


 def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

onnx_diagnostic/tasks/image_classification.py

Lines changed: 13 additions & 2 deletions
@@ -41,7 +41,6 @@ def get_inputs(
     :param input_height: input height
     :return: dictionary
     """
-    assert not add_second_input, "add_second_input=True not yet implemented"
     assert isinstance(
         input_width, int
     ), f"Unexpected type for input_width {type(input_width)}{config}"
@@ -61,7 +60,19 @@ def get_inputs(
             -1, 1
         ),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            input_width=input_width + 1,
+            input_height=input_height + 1,
+            input_channels=input_channels,
+            batch_size=batch_size + 1,
+            dynamic_rope=dynamic_rope,
+            **kwargs,
+        )["inputs"]
+    return res


 def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 20 additions & 1 deletion
@@ -100,7 +100,26 @@ def get_inputs(
             torch.int64
         ),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            sequence_length2=sequence_length2 + 1,
+            n_images=n_images + 1,
+            dynamic_rope=dynamic_rope,
+            **kwargs,
+        )["inputs"]
+    return res


 def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

onnx_diagnostic/tasks/sentence_similarity.py

Lines changed: 11 additions & 2 deletions
@@ -35,7 +35,6 @@ def get_inputs(
         token_type_ids:T7s1x13[0,0:A0.0],
         attention_mask:T7s1x13[1,1:A1.0])
     """
-    assert not add_second_input, "add_second_input=True not yet implemented"
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"
     shapes = {
@@ -50,7 +49,17 @@ def get_inputs(
         token_type_ids=torch.zeros((batch_size, sequence_length)).to(torch.int64),
         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            dummy_max_token_id=dummy_max_token_id,
+            **kwargs,
+        )["inputs"]
+    return res


 def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

onnx_diagnostic/tasks/text2text_generation.py

Lines changed: 16 additions & 1 deletion
@@ -126,7 +126,22 @@ def get_inputs(
         # encoder_last_hidden_state=torch.randn(batch_size, sequence_length2, encoder_dim),
         # encoder_outputs=torch.randn(batch_size, sequence_length2, encoder_dim),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            head_dim=head_dim,
+            encoder_dim=encoder_dim,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            sequence_length2=sequence_length2 + 1,
+            **kwargs,
+        )["inputs"]
+    return res


 def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
