Skip to content

Commit e1a8e1c

Browse files
committed
fix cache
1 parent 63a1408 commit e1a8e1c

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

onnx_diagnostic/ext_test_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,7 @@ def assertEqualArrayAny(
10141014
msg_ = "\n".join(excs)
10151015
msg = f"{msg}\n{msg_}" if msg else msg_
10161016
raise AssertionError(f"Found {len(excs)} discrepancies\n{msg}")
1017-
elif expected.__class__.__name__ == "DynamicCache":
1017+
elif expected.__class__.__name__ in ("DynamicCache", "StaticCache"):
10181018
atts = {"key_cache", "value_cache"}
10191019
self.assertEqualArrayAny(
10201020
{k: expected.__dict__.get(k, None) for k in atts},

onnx_diagnostic/tasks/text_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def get_inputs(
174174
shapes = {
175175
"input_ids": {0: batch, 1: seq_length},
176176
"attention_mask": {0: batch, 2: "seq"},
177-
"cache_position": {1: "seq"},
177+
"cache_position": {0: "seq"},
178178
"past_key_values": [
179179
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
180180
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],

onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,13 @@ def get_tiny_llm(
5757
res = get_inputs(
5858
model,
5959
conf,
60-
dummy_max_token_id=config["vocab_size"],
61-
num_hidden_layers=config["num_hidden_layers"],
60+
dummy_max_token_id=config["vocab_size"], # type: ignore[arg-type]
61+
num_hidden_layers=config["num_hidden_layers"], # type: ignore[arg-type]
6262
batch_size=batch_size,
6363
sequence_length=sequence_length,
6464
sequence_length2=sequence_length2,
6565
dynamic_rope=dynamic_rope,
66-
num_key_value_heads=config["num_key_value_heads"],
66+
num_key_value_heads=config["num_key_value_heads"], # type: ignore[arg-type]
6767
cls_cache="StaticCache" if use_static_cache else "DynamicCache",
6868
)
6969

0 commit comments

Comments
 (0)