File tree Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -526,7 +526,7 @@ def test_moving_to_cpu_throws_warning(self):
526526        reason = "Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release." , 
527527        strict = True , 
528528    ) 
529-     def  test_pipeline_device_placement_works_with_nf4 (self ):
529+     def  test_pipeline_cuda_placement_works_with_nf4 (self ):
530530        transformer_nf4_config  =  BitsAndBytesConfig (
531531            load_in_4bit = True ,
532532            bnb_4bit_quant_type = "nf4" ,
@@ -560,7 +560,7 @@ def test_pipeline_device_placement_works_with_nf4(self):
560560        ).to (torch_device )
561561
562562        # Check if inference works. 
563-         _  =  pipeline_4bit ("table" , max_sequence_length = 20 , num_inference_steps = 2 )
563+         _  =  pipeline_4bit (self . prompt , max_sequence_length = 20 , num_inference_steps = 2 )
564564
565565        del  pipeline_4bit 
566566
Original file line number Diff line number Diff line change @@ -492,7 +492,7 @@ def test_generate_quality_dequantize(self):
492492        self .assertTrue (max_diff  <  1e-2 )
493493
494494        # 8bit models cannot be offloaded to CPU. 
495-         self .assertTrue (self .pipeline_8bit .transformer .device .type  ==  "cuda" )
495+         self .assertTrue (self .pipeline_8bit .transformer .device .type  ==  torch_device )
496496        # calling it again shouldn't be a problem 
497497        _  =  self .pipeline_8bit (
498498            prompt = self .prompt ,
@@ -534,7 +534,7 @@ def test_pipeline_cuda_placement_works_with_mixed_int8(self):
534534        ).to (device )
535535
536536        # Check if inference works. 
537-         _  =  pipeline_8bit ("table" , max_sequence_length = 20 , num_inference_steps = 2 )
537+         _  =  pipeline_8bit (self . prompt , max_sequence_length = 20 , num_inference_steps = 2 )
538538
539539        del  pipeline_8bit 
540540
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments