Skip to content

Commit 4d2ca88

Browse files
committed
fix static
1 parent fb1844b commit 4d2ca88

File tree

2 files changed: +5 additions, −2 deletions

_unittests/ut_torch_models/test_tiny_llms.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -1,7 +1,7 @@
 import copy
 import unittest
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
+from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers
 from onnx_diagnostic.torch_models.llms import get_tiny_llm
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.torch_export_patches import torch_export_patches
```
```diff
@@ -33,13 +33,15 @@ def test_tiny_llm_export_dynamic(self):
         got = ep.module()(**inputs)
         self.assertEqualArrayAny(expected, got)

+    @requires_transformers("4.52")
     def test_tiny_llm_run_static(self):
         data = get_tiny_llm(use_static_cache=True)
         model, inputs = data["model"], data["inputs"]
         self.assertIn("StaticCache", string_type(inputs))
         model(**inputs)

     @ignore_warnings(UserWarning)
+    @requires_transformers("4.52")
     def test_tiny_llm_export_static(self):
         data = get_tiny_llm(use_static_cache=True)
         model, inputs = data["model"], data["inputs"]
```

onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -50,7 +50,8 @@ def get_tiny_llm(

     config.update(**kwargs)
     conf = transformers.LlamaConfig(**config)
-    conf.cache_implementation = "static"
+    if use_static_cache:
+        conf.cache_implementation = "static"
     model = transformers.LlamaForCausalLM(conf)
     model.eval()
```

0 commit comments

Comments (0)