Skip to content

Commit 9ff0014

Browse files
committed
Use dynamic-shape strings instead of torch.export.Dim objects
1 parent 1eab135 commit 9ff0014

File tree

4 files changed

+17
-17
lines changed

4 files changed

+17
-17
lines changed

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,19 @@ def get_inputs(
5252
:return: dictionary
5353
"""
5454
batch = torch.export.Dim("batch", min=1, max=1024)
55-
seq_length = torch.export.Dim("seq_length", min=1, max=4096)
56-
cache_length = torch.export.Dim("cache_length", min=1, max=4096)
57-
images = torch.export.Dim("images", min=1, max=4096)
55+
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
56+
cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
57+
images = "images" # torch.export.Dim("images", min=1, max=4096)
5858

5959
shapes = {
6060
"input_ids": {0: batch, 1: seq_length},
6161
"attention_mask": {
6262
0: batch,
63-
1: torch.export.Dim.DYNAMIC, # cache_length + seq_length
63+
1: "cache+seq", # cache_length + seq_length
6464
},
6565
"position_ids": {
6666
0: batch,
67-
1: torch.export.Dim.DYNAMIC, # cache_length + seq_length
67+
1: "cache+seq", # cache_length + seq_length
6868
},
6969
"past_key_values": [
7070
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],

onnx_diagnostic/tasks/text2text_generation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ def get_inputs(
5959
encoder_outputs:dict(last_hidden_state:T1s1x16x512)
6060
"""
6161
batch = torch.export.Dim("batch", min=1, max=1024)
62-
seq_length = torch.export.Dim("seq_length", min=1, max=4096)
63-
cache_length = torch.export.Dim("cache_length", min=1, max=4096)
64-
cache_length2 = torch.export.Dim("cache_length2", min=1, max=4096)
62+
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
63+
cache_length = "cache_length_key" # torch.export.Dim("cache_length", min=1, max=4096)
64+
cache_length2 = "cache_length_val" # torch.export.Dim("cache_length2", min=1, max=4096)
6565

6666
shapes = {
6767
"input_ids": {0: batch, 1: seq_length},
68-
"decoder_input_ids": {0: batch, 1: torch.export.Dim.DYNAMIC},
69-
"attention_mask": {0: batch, 1: torch.export.Dim.DYNAMIC},
68+
"decoder_input_ids": {0: batch, 1: "seq_ids"},
69+
"attention_mask": {0: batch, 1: "seq_mask"},
7070
# "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC},
7171
"past_key_values": [
7272
[

onnx_diagnostic/tasks/text_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def get_inputs(
3535
attention_mask:T7s1x13[1,1:A1.0])
3636
"""
3737
batch = torch.export.Dim("batch", min=1, max=1024)
38-
seq_length = torch.export.Dim("sequence_length", min=1, max=1024)
38+
seq_length = "seq_length" # torch.export.Dim("sequence_length", min=1, max=1024)
3939
shapes = {
4040
"input_ids": {0: batch, 1: seq_length},
4141
"token_type_ids": {0: batch, 1: seq_length},

onnx_diagnostic/tasks/text_generation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def get_inputs(
8484
:return: dictionary
8585
"""
8686
batch = torch.export.Dim("batch", min=1, max=1024)
87-
seq_length = torch.export.Dim("seq_length", min=1, max=4096)
88-
cache_length = torch.export.Dim("cache_length", min=1, max=4096)
87+
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
88+
cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
8989

9090
if config is not None and config.__class__.__name__ == "FalconMambaConfig":
9191
seq_length_multiple = 8
@@ -101,11 +101,11 @@ def get_inputs(
101101
"input_ids": {0: batch, 1: torch.export.Dim.DYNAMIC},
102102
"attention_mask": {
103103
0: batch,
104-
1: torch.export.Dim.DYNAMIC, # cache_length + seq_length
104+
1: "cache+seq", # cache_length + seq_length
105105
},
106106
"cache_position": {
107107
0: batch,
108-
1: torch.export.Dim.DYNAMIC, # cache_length + seq_length
108+
1: "cache+seq", # cache_length + seq_length
109109
},
110110
"cache_params": [
111111
[{0: batch} for _ in range(num_hidden_layers)],
@@ -145,11 +145,11 @@ def get_inputs(
145145
"input_ids": {0: batch, 1: seq_length},
146146
"attention_mask": {
147147
0: batch,
148-
1: torch.export.Dim.DYNAMIC, # cache_length + seq_length
148+
1: "cache+seq", # cache_length + seq_length
149149
},
150150
"position_ids": {
151151
0: batch,
152-
1: torch.export.Dim.DYNAMIC, # cache_length + seq_length
152+
1: "cache+seq", # cache_length + seq_length
153153
},
154154
"past_key_values": [
155155
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],

0 commit comments

Comments (0)