Skip to content

Commit 9df1cdc

Browse files
committed
inputs2
1 parent 5fb1fb9 commit 9df1cdc

13 files changed

+112
-2
lines changed

onnx_diagnostic/tasks/automatic_speech_recognition.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ def get_inputs(
3333
head_dim: int,
3434
batch_size: int = 2,
3535
sequence_length: int = 30,
36+
add_second_input: bool = False,
3637
**kwargs, # unused
3738
):
3839
"""
@@ -68,6 +69,7 @@ def get_inputs(
6869
use_cache:bool,return_dict:bool
6970
)
7071
"""
72+
assert not add_second_input, "add_second_input=True not yet implemented"
7173
batch = torch.export.Dim("batch", min=1, max=1024)
7274
seq_length = "seq_length"
7375

onnx_diagnostic/tasks/feature_extraction.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ def get_inputs(
2222
batch_size: int,
2323
sequence_length: int,
2424
dummy_max_token_id: int,
25+
add_second_input: bool = False,
2526
**kwargs, # unused
2627
):
2728
"""
@@ -34,6 +35,7 @@ def get_inputs(
3435
token_type_ids:T7s1x13[0,0:A0.0],
3536
attention_mask:T7s1x13[1,1:A1.0])
3637
"""
38+
assert not add_second_input, "add_second_input=True not yet implemented"
3739
batch = torch.export.Dim("batch", min=1, max=1024)
3840
seq_length = "sequence_length"
3941
shapes = {

onnx_diagnostic/tasks/fill_mask.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ def get_inputs(
2222
batch_size: int,
2323
sequence_length: int,
2424
dummy_max_token_id: int,
25+
add_second_input: bool = False,
2526
**kwargs, # unused
2627
):
2728
"""
@@ -34,6 +35,7 @@ def get_inputs(
3435
token_type_ids:T7s1x13[0,0:A0.0],
3536
attention_mask:T7s1x13[1,1:A1.0])
3637
"""
38+
assert not add_second_input, "add_second_input=True not yet implemented"
3739
batch = torch.export.Dim("batch", min=1, max=1024)
3840
seq_length = "sequence_length"
3941
shapes = {

onnx_diagnostic/tasks/image_classification.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ def get_inputs(
2727
input_channels: int,
2828
batch_size: int = 2,
2929
dynamic_rope: bool = False,
30+
add_second_input: bool = False,
3031
**kwargs, # unused
3132
):
3233
"""
@@ -40,6 +41,7 @@ def get_inputs(
4041
:param input_height: input height
4142
:return: dictionary
4243
"""
44+
assert not add_second_input, "add_second_input=True not yet implemented"
4345
assert isinstance(
4446
input_width, int
4547
), f"Unexpected type for input_width {type(input_width)}{config}"

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@ def get_inputs(
3232
sequence_length2: int = 3,
3333
n_images: int = 2,
3434
dynamic_rope: bool = False,
35+
add_second_input: bool = False,
3536
**kwargs, # unused
3637
):
3738
"""

onnx_diagnostic/tasks/mixture_of_expert.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@ def get_inputs(
4141
sequence_length2: int = 3,
4242
n_images: int = 2,
4343
dynamic_rope: bool = False,
44+
add_second_input: bool = False,
4445
**kwargs, # unused
4546
):
4647
"""
@@ -60,6 +61,7 @@ def get_inputs(
6061
:param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
6162
:return: dictionary
6263
"""
64+
assert not add_second_input, "add_second_input=True not yet implemented"
6365
raise NotImplementedError(f"get_inputs not yet implemented for task {__TASK__!r}.")
6466

6567

onnx_diagnostic/tasks/sentence_similarity.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ def get_inputs(
2222
batch_size: int,
2323
sequence_length: int,
2424
dummy_max_token_id: int,
25+
add_second_input: bool = False,
2526
**kwargs, # unused
2627
):
2728
"""
@@ -34,6 +35,7 @@ def get_inputs(
3435
token_type_ids:T7s1x13[0,0:A0.0],
3536
attention_mask:T7s1x13[1,1:A1.0])
3637
"""
38+
assert not add_second_input, "add_second_input=True not yet implemented"
3739
batch = torch.export.Dim("batch", min=1, max=1024)
3840
seq_length = "seq_length"
3941
shapes = {

onnx_diagnostic/tasks/text2text_generation.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@ def get_inputs(
2828
batch_size: int = 2,
2929
sequence_length: int = 30,
3030
sequence_length2: int = 3,
31+
add_second_input: bool = False,
3132
**kwargs, # unused
3233
):
3334
"""

onnx_diagnostic/tasks/text_classification.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ def get_inputs(
2222
batch_size: int,
2323
sequence_length: int,
2424
dummy_max_token_id: int,
25+
add_second_input: bool = False,
2526
**kwargs, # unused
2627
):
2728
"""
@@ -34,6 +35,7 @@ def get_inputs(
3435
token_type_ids:T7s1x13[0,0:A0.0],
3536
attention_mask:T7s1x13[1,1:A1.0])
3637
"""
38+
assert not add_second_input, "add_second_input=True not yet implemented"
3739
batch = torch.export.Dim("batch", min=1, max=1024)
3840
seq_length = "seq_length" # torch.export.Dim("sequence_length", min=1, max=1024)
3941
shapes = {

onnx_diagnostic/tasks/text_generation.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,7 @@ def get_inputs(
7171
num_key_value_heads: Optional[int] = None,
7272
head_dim: Optional[int] = None,
7373
cls_cache: Optional[Union[type, str]] = None,
74+
add_second_input: bool = False,
7475
**kwargs, # unused
7576
):
7677
"""
@@ -88,6 +89,7 @@ def get_inputs(
8889
:class:`transformers.cache_utils.DynamicCache`
8990
:return: dictionary
9091
"""
92+
assert not add_second_input, "add_second_input=True not yet implemented"
9193
batch = torch.export.Dim("batch", min=1, max=1024)
9294
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
9395
cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
@@ -192,7 +194,23 @@ def get_inputs(
192194
]
193195
),
194196
)
195-
return dict(inputs=inputs, dynamic_shapes=shapes)
197+
res = dict(inputs=inputs, dynamic_shapes=shapes)
198+
if add_second_input:
199+
res["inputs2"] = get_inputs(
200+
model=model,
201+
config=config,
202+
dummy_max_token_id=dummy_max_token_id,
203+
num_hidden_layers=num_hidden_layers,
204+
batch_size=batch_size + 1,
205+
sequence_length=sequence_length + 1,
206+
sequence_length2=sequence_length2 + 1,
207+
dynamic_rope=dynamic_rope,
208+
num_key_value_heads=num_key_value_heads,
209+
head_dim=head_dim,
210+
cls_cache=cls_cache,
211+
**kwargs,
212+
)
213+
return res
196214

197215

198216
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

0 commit comments

Comments (0)