2 files changed: +5 −2

_unittests/ut_torch_models
```diff
 import copy
 import unittest
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
+from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers
 from onnx_diagnostic.torch_models.llms import get_tiny_llm
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.torch_export_patches import torch_export_patches
@@ -33,13 +33,15 @@ def test_tiny_llm_export_dynamic(self):
         got = ep.module()(**inputs)
         self.assertEqualArrayAny(expected, got)
 
+    @requires_transformers("4.52")
     def test_tiny_llm_run_static(self):
         data = get_tiny_llm(use_static_cache=True)
         model, inputs = data["model"], data["inputs"]
         self.assertIn("StaticCache", string_type(inputs))
         model(**inputs)
 
     @ignore_warnings(UserWarning)
+    @requires_transformers("4.52")
     def test_tiny_llm_export_static(self):
         data = get_tiny_llm(use_static_cache=True)
         model, inputs = data["model"], data["inputs"]
```
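Both static-cache tests are now gated on transformers 4.52 or newer. The implementation of `requires_transformers` is not part of this diff; a minimal sketch of such a version-gating decorator, assuming it simply skips the test when the installed transformers release is too old, could look like this:

```python
# Hypothetical sketch of a version-gating decorator in the spirit of
# requires_transformers; the real helper lives in
# onnx_diagnostic.ext_test_case and may behave differently.
import unittest

import transformers
from packaging.version import Version


def requires_transformers_sketch(min_version: str):
    """Skip the decorated test when transformers is older than min_version."""
    installed = Version(transformers.__version__)
    # unittest.skipIf returns a decorator, so it can be returned directly.
    return unittest.skipIf(
        installed < Version(min_version),
        f"requires transformers>={min_version}, found {installed}",
    )
```

Applied as `@requires_transformers_sketch("4.52")`, this reports the test as skipped instead of failing on older installations.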
onnx_diagnostic/torch_models/untrained

```diff
@@ -50,7 +50,8 @@ def get_tiny_llm(
 
     config.update(**kwargs)
     conf = transformers.LlamaConfig(**config)
-    conf.cache_implementation = "static"
+    if use_static_cache:
+        conf.cache_implementation = "static"
     model = transformers.LlamaForCausalLM(conf)
     model.eval()
```
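The second change makes the static cache opt-in: `get_tiny_llm` now sets `conf.cache_implementation = "static"` only when `use_static_cache` is True, instead of forcing it on every configuration. A usage sketch, inferred from how the tests above consume the function (its exact signature and defaults are an assumption):

```python
# Sketch inferred from the tests in this diff; the shape of the returned
# dict ("model" and "inputs" keys) is taken from those tests.
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.helpers import string_type

# Opt-in path: the config requests cache_implementation="static"
# and the prepared inputs carry a StaticCache.
data = get_tiny_llm(use_static_cache=True)
model, inputs = data["model"], data["inputs"]
assert "StaticCache" in string_type(inputs)
model(**inputs)
```

Guarding the assignment keeps the default configuration untouched for callers that rely on the dynamic cache.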