Skip to content

Commit f15360e

Browse files
committed
fix static cache
1 parent d817f19 commit f15360e

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

onnx_diagnostic/tasks/text_generation.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -173,37 +173,34 @@ def get_inputs(
173173
# static
174174
shapes = {
175175
"input_ids": {0: batch, 1: seq_length},
176-
"attention_mask": {0: batch, 2: "sequence_length+past_sequence_length"},
177-
"cache_position": {0: "sequence_length+past_sequence_length"},
176+
"attention_mask": {0: batch, 2: "past_sequence_length"},
177+
"cache_position": {0: "past_sequence_length"},
178178
"past_key_values": [
179-
# [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
180-
# [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
179+
# past_sequence_length is now static
181180
[{0: batch} for _ in range(num_hidden_layers)],
182181
[{0: batch} for _ in range(num_hidden_layers)],
183182
],
184183
}
185184
inputs = dict(
186185
input_ids=torch.randint(
187-
0, dummy_max_token_id, (batch_size, sequence_length)
186+
0, dummy_max_token_id, (batch_size, past_sequence_length)
188187
).to(torch.int64),
189188
attention_mask=torch.ones(
190189
(
191190
batch_size,
192191
num_key_value_heads,
193-
past_sequence_length + sequence_length,
192+
past_sequence_length,
194193
head_dim,
195194
)
196195
).to(torch.bool),
197-
cache_position=torch.arange(past_sequence_length + sequence_length).to(
198-
torch.int64
199-
),
196+
cache_position=torch.arange(past_sequence_length).to(torch.int64),
200197
past_key_values=make_static_cache(
201198
[
202199
(
203200
torch.randn(
204201
batch_size,
205202
num_key_value_heads,
206-
past_sequence_length + sequence_length,
203+
sequence_length + past_sequence_length,
207204
head_dim,
208205
),
209206
torch.randn(
@@ -215,7 +212,7 @@ def get_inputs(
215212
)
216213
for i in range(num_hidden_layers)
217214
],
218-
max_cache_len=max(sequence_length + past_sequence_length, head_dim),
215+
max_cache_len=max(past_sequence_length, head_dim),
219216
),
220217
)
221218
else:

0 commit comments

Comments (0)