Skip to content

Commit 50f72f1

Browse files
committed
mechanism for inputs2
1 parent e7af4b0 commit 50f72f1

13 files changed

+56
-6
lines changed

onnx_diagnostic/tasks/automatic_speech_recognition.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def get_inputs(
132132
)
133133
res = dict(inputs=inputs, dynamic_shapes=shapes)
134134
if add_second_input:
135+
assert (
136+
add_second_input > 0
137+
), f"Not implemented for add_second_input={add_second_input}."
135138
res["inputs2"] = get_inputs(
136139
model=model,
137140
config=config,
@@ -145,6 +148,7 @@ def get_inputs(
145148
head_dim=head_dim,
146149
batch_size=batch_size + 1,
147150
sequence_length=sequence_length + add_second_input,
151+
add_second_input=0,
148152
**kwargs,
149153
)["inputs"]
150154
return res

onnx_diagnostic/tasks/feature_extraction.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,16 @@ def get_inputs(
5252
)
5353
res = dict(inputs=inputs, dynamic_shapes=shapes)
5454
if add_second_input:
55+
assert (
56+
add_second_input > 0
57+
), f"Not implemented for add_second_input={add_second_input}."
5558
res["inputs2"] = get_inputs(
5659
model=model,
5760
config=config,
5861
batch_size=batch_size + 1,
5962
sequence_length=sequence_length + add_second_input,
6063
dummy_max_token_id=dummy_max_token_id,
64+
add_second_input=0,
6165
**kwargs,
6266
)["inputs"]
6367
return res

onnx_diagnostic/tasks/fill_mask.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,16 @@ def get_inputs(
5454
)
5555
res = dict(inputs=inputs, dynamic_shapes=shapes)
5656
if add_second_input:
57+
assert (
58+
add_second_input > 0
59+
), f"Not implemented for add_second_input={add_second_input}."
5760
res["inputs2"] = get_inputs(
5861
model=model,
5962
config=config,
6063
batch_size=batch_size + 1,
6164
sequence_length=sequence_length + add_second_input,
6265
dummy_max_token_id=dummy_max_token_id,
66+
add_second_input=0,
6367
**kwargs,
6468
)["inputs"]
6569
return res

onnx_diagnostic/tasks/image_classification.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def get_inputs(
7575
shapes["interpolate_pos_encoding"] = None # type: ignore[assignment]
7676
res = dict(inputs=inputs, dynamic_shapes=shapes)
7777
if add_second_input:
78+
assert (
79+
add_second_input > 0
80+
), f"Not implemented for add_second_input={add_second_input}."
7881
res["inputs2"] = get_inputs(
7982
model=model,
8083
config=config,
@@ -83,6 +86,7 @@ def get_inputs(
8386
input_channels=input_channels,
8487
batch_size=batch_size + 1,
8588
dynamic_rope=dynamic_rope,
89+
add_second_input=0,
8690
**kwargs,
8791
)["inputs"]
8892
return res

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ def get_inputs(
105105
)
106106
res = dict(inputs=inputs, dynamic_shapes=shapes)
107107
if add_second_input:
108+
assert (
109+
add_second_input > 0
110+
), f"Not implemented for add_second_input={add_second_input}."
108111
res["inputs2"] = get_inputs(
109112
model=model,
110113
config=config,
@@ -117,9 +120,10 @@ def get_inputs(
117120
num_channels=num_channels,
118121
batch_size=batch_size + 1,
119122
sequence_length=sequence_length + add_second_input,
120-
sequence_length2=sequence_length2 + add_second_input,
123+
sequence_length2=sequence_length2 + 1,
121124
n_images=n_images + 1,
122125
dynamic_rope=dynamic_rope,
126+
add_second_input=0,
123127
**kwargs,
124128
)["inputs"]
125129
return res

onnx_diagnostic/tasks/object_detection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def get_inputs(
6565
)
6666
res = dict(inputs=inputs, dynamic_shapes=shapes)
6767
if add_second_input:
68+
assert (
69+
add_second_input > 0
70+
), f"Not implemented for add_second_input={add_second_input}."
6871
res["inputs2"] = get_inputs(
6972
model=model,
7073
config=config,
@@ -73,6 +76,7 @@ def get_inputs(
7376
input_channels=input_channels,
7477
batch_size=batch_size + 1,
7578
dynamic_rope=dynamic_rope,
79+
add_second_input=0,
7680
**kwargs,
7781
)["inputs"]
7882
return res

onnx_diagnostic/tasks/sentence_similarity.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,16 @@ def get_inputs(
5454
)
5555
res = dict(inputs=inputs, dynamic_shapes=shapes)
5656
if add_second_input:
57+
assert (
58+
add_second_input > 0
59+
), f"Not implemented for add_second_input={add_second_input}."
5760
res["inputs2"] = get_inputs(
5861
model=model,
5962
config=config,
6063
batch_size=batch_size + 1,
6164
sequence_length=sequence_length + add_second_input,
6265
dummy_max_token_id=dummy_max_token_id,
66+
add_second_input=0,
6367
**kwargs,
6468
)["inputs"]
6569
return res

onnx_diagnostic/tasks/summarization.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ def get_inputs(
144144
)
145145
res = dict(inputs=inputs, dynamic_shapes=shapes)
146146
if add_second_input:
147+
assert (
148+
add_second_input > 0
149+
), f"Not implemented for add_second_input={add_second_input}."
147150
res["inputs2"] = get_inputs(
148151
model=model,
149152
config=config,
@@ -155,7 +158,8 @@ def get_inputs(
155158
head_dim_decoder=head_dim_decoder,
156159
batch_size=batch_size + 1,
157160
sequence_length=sequence_length + add_second_input,
158-
sequence_length2=sequence_length2 + add_second_input,
161+
sequence_length2=sequence_length2 + 1,
162+
add_second_input=0,
159163
**kwargs,
160164
)["inputs"]
161165
return res

onnx_diagnostic/tasks/text2text_generation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ def get_inputs(
149149
)
150150
res = dict(inputs=inputs, dynamic_shapes=shapes)
151151
if add_second_input:
152+
assert (
153+
add_second_input > 0
154+
), f"Not implemented for add_second_input={add_second_input}."
152155
res["inputs2"] = get_inputs(
153156
model=model,
154157
config=config,
@@ -161,7 +164,8 @@ def get_inputs(
161164
encoder_dim=encoder_dim,
162165
batch_size=batch_size + 1,
163166
sequence_length=sequence_length + add_second_input,
164-
sequence_length2=sequence_length2 + add_second_input,
167+
sequence_length2=sequence_length2 + 1,
168+
add_second_input=0,
165169
**kwargs,
166170
)["inputs"]
167171
return res

onnx_diagnostic/tasks/text_classification.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,16 @@ def get_inputs(
5454
)
5555
res = dict(inputs=inputs, dynamic_shapes=shapes)
5656
if add_second_input:
57+
assert (
58+
add_second_input > 0
59+
), f"Not implemented for add_second_input={add_second_input}."
5760
res["inputs2"] = get_inputs(
5861
model=model,
5962
config=config,
6063
batch_size=batch_size + 1,
6164
sequence_length=sequence_length + add_second_input,
6265
dummy_max_token_id=dummy_max_token_id,
66+
add_second_input=0,
6367
**kwargs,
6468
)["inputs"]
6569
return res

0 commit comments

Comments
 (0)