Skip to content

Commit dd3d0ae

Browse files
committed
Fix static cache
1 parent 83b3b6a commit dd3d0ae

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

_unittests/ut_torch_models/test_tiny_llms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_tiny_llm_export_static(self):
5656
self.assertEqual(
5757
{"attention_mask", "past_key_values", "input_ids", "cache_position"}, set(inputs)
5858
)
59-
with torch_export_patches(patch_transformers=True, stop_if_static=1):
59+
with torch_export_patches(patch_transformers=True, stop_if_static=0):
6060
ep = torch.export.export(
6161
model,
6262
(),

onnx_diagnostic/tasks/text_generation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,10 @@ def get_inputs(
176176
"attention_mask": {0: batch, 2: "seq"},
177177
"cache_position": {0: "seq"},
178178
"past_key_values": [
179-
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
180-
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
179+
# [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
180+
# [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
181+
[{0: batch} for _ in range(num_hidden_layers)],
182+
[{0: batch} for _ in range(num_hidden_layers)],
181183
],
182184
}
183185
inputs = dict(

0 commit comments

Comments
 (0)