Skip to content

Commit 5195875

Browse files
committed
fix
1 parent 3251a8c commit 5195875

File tree

7 files changed

+16
-50
lines changed

7 files changed

+16
-50
lines changed

_unittests/ut_export/test_shape_helper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,10 @@ def test_make_fake_with_dynamic_dimensions_whole(self):
225225
"attention_mask": {0: "batch", 1: "cache+seq"},
226226
"position_ids": {0: "batch", 1: "seq_length"},
227227
"past_key_values": [
228-
[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
229-
[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
228+
{0: "batch", 2: "cache_length"},
229+
{0: "batch", 2: "cache_length"},
230+
{0: "batch", 2: "cache_length"},
231+
{0: "batch", 2: "cache_length"},
230232
],
231233
},
232234
)

onnx_diagnostic/tasks/automatic_speech_recognition.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,8 @@ def get_inputs(
8484
"cache_position": {0: seq_length},
8585
"encoder_outputs": [{0: batch}], # last_hidden_state
8686
"past_key_values": [
87-
[
88-
[{0: batch} for _ in range(num_hidden_layers)],
89-
[{0: batch} for _ in range(num_hidden_layers)],
90-
],
91-
[
92-
[{0: batch} for _ in range(num_hidden_layers)],
93-
[{0: batch} for _ in range(num_hidden_layers)],
94-
],
87+
[{0: batch} for _ in range(num_hidden_layers * 2)],
88+
[{0: batch} for _ in range(num_hidden_layers * 2)],
9589
],
9690
}
9791
inputs = dict(

onnx_diagnostic/tasks/feature_extraction.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,8 @@ def get_inputs(
109109
cache_length = "cache_length_key"
110110
cache_length2 = "cache_length_val"
111111
shapes["past_key_values"] = [ # type: ignore[assignment]
112-
[
113-
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
114-
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
115-
],
116-
[
117-
[{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
118-
[{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
119-
],
112+
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
113+
[{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
120114
]
121115

122116
res = dict(inputs=inputs, dynamic_shapes=shapes)

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,7 @@ def _get_inputs_gemma3(
151151
},
152152
"position_ids": {0: batch, 1: seq_length},
153153
"cache_position": {0: seq_length},
154-
"past_key_values": [
155-
[{0: batch} for _ in range(num_hidden_layers)],
156-
[{0: batch} for _ in range(num_hidden_layers)],
157-
],
154+
"past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)],
158155
"pixel_values": {0: batch},
159156
"use_cache": None,
160157
}

onnx_diagnostic/tasks/summarization.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,8 @@ def get_inputs(
8181
"attention_mask": {0: batch, 1: "seq_mask"},
8282
# "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC},
8383
"past_key_values": [
84-
[
85-
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
86-
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
87-
],
88-
[
89-
[{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
90-
[{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
91-
],
84+
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
85+
[{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
9286
],
9387
# one of these is selected based on the forward method signature
9488
# "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC},

onnx_diagnostic/tasks/text2text_generation.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,8 @@ def get_inputs(
8383
"attention_mask": {0: batch, 1: "seq_mask"},
8484
# "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC},
8585
"past_key_values": [
86-
[
87-
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
88-
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
89-
],
90-
[
91-
[{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
92-
[{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
93-
],
86+
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
87+
[{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
9488
],
9589
# one of these is selected based on the forward method signature
9690
# "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC},

onnx_diagnostic/tasks/text_generation.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,7 @@ def get_inputs(
119119
0: batch,
120120
1: "cache+seq", # cache_length + seq_length
121121
},
122-
"cache_params": [
123-
[{0: batch} for _ in range(num_hidden_layers)],
124-
[{0: batch} for _ in range(num_hidden_layers)],
125-
],
122+
"cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)],
126123
}
127124
inputs = dict(
128125
input_ids=torch.randint(
@@ -176,12 +173,7 @@ def get_inputs(
176173
"input_ids": {0: batch, 1: seq_length},
177174
"attention_mask": {0: batch, 2: "seq"},
178175
"cache_position": {0: "seq"},
179-
"past_key_values": [
180-
# [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
181-
# [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
182-
[{0: batch} for _ in range(num_hidden_layers)],
183-
[{0: batch} for _ in range(num_hidden_layers)],
184-
],
176+
"past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)],
185177
}
186178
inputs = dict(
187179
input_ids=torch.randint(
@@ -222,8 +214,7 @@ def get_inputs(
222214
},
223215
"position_ids": {0: batch, 1: seq_length},
224216
"past_key_values": [
225-
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
226-
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
217+
{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)
227218
],
228219
}
229220

0 commit comments

Comments
 (0)