@@ -61,6 +61,8 @@ class BuilderArgs:
6161    dynamic_shapes : bool  =  False 
6262    max_seq_length : Optional [int ] =  None 
6363
64+     quantized_state_path : Optional [Union [Path , str ]] =  None 
65+ 
6466    def  __post_init__ (self ):
6567        if  self .device  is  None :
6668            self .device  =  "cuda"  if  torch .cuda .is_available () else  "cpu" 
@@ -171,6 +173,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
171173            is_chat_model = is_chat_model ,
172174            dynamic_shapes = getattr (args , "dynamic_shapes" , False ),
173175            max_seq_length = getattr (args , "max_seq_length" , None ),
176+             quantized_state_path = getattr (args , "quantized_state_path" , None ),
174177        )
175178
176179    @classmethod  
@@ -565,25 +568,43 @@ def _initialize_model(
565568            model  =  _load_model (builder_args )
566569            device_sync (device = builder_args .device )
567570
568-         if  quantize :
569-             print (f"Quantizing the model with: { quantize }  " )
570-             with  measure_time ("Time to quantize model: {time:.02f} seconds" ):
571-                 quantize_model (
572-                     model ,
573-                     builder_args .device ,
574-                     quantize ,
575-                     tokenizer ,
576-                     support_tensor_subclass ,
577-                 )
578-                 device_sync (device = builder_args .device )
571+         cache_path  =  builder_args .quantized_state_path 
572+         quant_checkpoint_exists : bool  =  cache_path  is  not  None  and  os .path .isfile (cache_path )
573+         if  quantize  or  quant_checkpoint_exists :
579574
580-         if  builder_args .setup_caches :
581-             with  torch .device (builder_args .device ):
582-                 model .setup_caches (
583-                     max_batch_size = 1 ,
584-                     max_seq_length = max_seq_length 
585-                     or  model .text_transformer_args .max_seq_length ,
586-                 )
575+             if  quantize  and  quant_checkpoint_exists :
576+                 print ("WARNING: Both a quantized checkpoint and quantize arg were provided; Ignoring quantize arg" )
577+ 
578+             if  quant_checkpoint_exists :
579+                 with  measure_time ("Time to load quantized state: {time:.02f} seconds" ):
580+                     print (f"Loading the model_state in: { cache_path }  " )
581+                     model .load_state_dict (torch .load (cache_path , map_location = builder_args .device ))
582+                     device_sync (device = builder_args .device )
583+             else :
584+                 with  measure_time ("Time to quantize model: {time:.02f} seconds" ):
585+                     print (f"Quantizing the model with: { quantize }  " )
586+                     quantize_model (
587+                         model ,
588+                         builder_args .device ,
589+                         quantize ,
590+                         tokenizer ,
591+                         support_tensor_subclass ,
592+                     )
593+                     device_sync (device = builder_args .device )
594+ 
595+                 if  cache_path :
596+                     with  measure_time ("Time to save quantized state: {time:.02f} seconds" ):
597+                         print ("Saving the quantized state dict" )
598+                         torch .save (model .state_dict (), cache_path )
599+ 
600+ 
601+         if  builder_args .setup_caches :
602+             with  torch .device (builder_args .device ):
603+                 model .setup_caches (
604+                     max_batch_size = 1 ,
605+                     max_seq_length = max_seq_length 
606+                     or  model .text_transformer_args .max_seq_length ,
607+                 )
587608
588609        model .to (dtype = builder_args .precision )
589610
0 commit comments