@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import re
+import shutil
 from typing import Optional
 from unittest.mock import MagicMock
 
@@ -34,6 +35,16 @@
 from transformers import AutoModelForCausalLM
 
 
+@pytest.fixture(scope="module", autouse=True)
+def cleanup_model_cache():
+    """Clean up the test model cache directory after all tests complete."""
+    yield
+    try:
+        shutil.rmtree("test-apply-model-cache", ignore_errors=True)
+    except Exception:
+        pass
+
+
 @pytest.fixture
 def mock_model():
     model = MagicMock()
@@ -55,6 +66,7 @@ def llama_stories_model():
     return AutoModelForCausalLM.from_pretrained(
         "Xenova/llama2.c-stories15M",
         torch_dtype="auto",
+        cache_dir="test-apply-model-cache",
     )
 
 
@@ -87,7 +99,8 @@ def test_target_prioritization(mock_frozen):
     }
 
     model = AutoModelForCausalLM.from_pretrained(
-        "HuggingFaceM4/tiny-random-LlamaForCausalLM", torch_dtype="auto"
+        "HuggingFaceM4/tiny-random-LlamaForCausalLM", torch_dtype="auto",
+        cache_dir="test-apply-model-cache"
     )
     model.eval()
 
@@ -185,6 +198,7 @@ def get_tinyllama_model():
     return AutoModelForCausalLM.from_pretrained(
         "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        torch_dtype="auto",
+        cache_dir="test-apply-model-cache",
     )
 
 
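The new cleanup_model_cache fixture relies on pytest's yield-fixture teardown: scope="module" plus autouse=True makes it run once for this test module without any test requesting it, and the code after the yield executes only once every test in the module has finished. A minimal standalone sketch of the same pattern (the tmp-model-cache name here is illustrative, not part of this change):

import shutil

import pytest


@pytest.fixture(scope="module", autouse=True)
def cleanup_cache():
    # No setup needed; the module's tests run while the fixture is
    # suspended at the yield.
    yield
    # Teardown: remove the cache directory. ignore_errors=True swallows
    # a missing directory, so this is safe to call unconditionally.
    shutil.rmtree("tmp-model-cache", ignore_errors=True)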