@@ -21,7 +21,7 @@ def setUp(self):
2121 def test_oneshot_sparsification_then_finetune (self ):
2222 recipe_str = "tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml"
2323 model = AutoModelForCausalLM .from_pretrained (
24- "nm-testing/llama2.c-stories15M" , device_map = "auto" , torch_dtype = "auto"
24+ "nm-testing/llama2.c-stories15M" , torch_dtype = "auto"
2525 )
2626 dataset = "open_platypus"
2727 concatenate_data = False
@@ -47,12 +47,11 @@ def test_oneshot_sparsification_then_finetune(self):
4747 # Explicitly decompress the model for training using quantization_config
4848 model = AutoModelForCausalLM .from_pretrained (
4949 self .output / "oneshot_out" ,
50- device_map = "auto" ,
5150 torch_dtype = "auto" ,
5251 quantization_config = self .quantization_config ,
5352 )
5453 distill_teacher = AutoModelForCausalLM .from_pretrained (
55- "nm-testing/llama2.c-stories15M" , device_map = "auto" , torch_dtype = "auto"
54+ "nm-testing/llama2.c-stories15M" , torch_dtype = "auto"
5655 )
5756 dataset = "open_platypus"
5857 concatenate_data = False
@@ -88,7 +87,6 @@ def test_oneshot_sparsification_then_finetune(self):
8887 # Explicitly decompress the model for training using quantization_config
8988 model = AutoModelForCausalLM .from_pretrained (
9089 output_dir ,
91- device_map = "auto" ,
9290 torch_dtype = "auto" ,
9391 quantization_config = self .quantization_config ,
9492 )
@@ -112,7 +110,7 @@ def test_oneshot_quantization_then_finetune(self):
112110 )
113111
114112 model = AutoModelForCausalLM .from_pretrained (
115- "TinyLlama/TinyLlama-1.1B-Chat-v1.0" , device_map = "auto" , torch_dtype = "auto"
113+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0" , torch_dtype = "auto"
116114 )
117115 dataset = "open_platypus"
118116 concatenate_data = False
@@ -136,7 +134,6 @@ def test_oneshot_quantization_then_finetune(self):
136134 quantization_config = CompressedTensorsConfig (run_compressed = False )
137135 model = AutoModelForCausalLM .from_pretrained (
138136 output_dir ,
139- device_map = "auto" ,
140137 torch_dtype = "auto" ,
141138 quantization_config = quantization_config ,
142139 )
@@ -159,7 +156,6 @@ def test_oneshot_quantization_then_finetune(self):
159156 # test reloading checkpoint and final model
160157 model = AutoModelForCausalLM .from_pretrained (
161158 output_dir ,
162- device_map = "auto" ,
163159 torch_dtype = "auto" ,
164160 quantization_config = quantization_config ,
165161 )
0 commit comments