 
 # from functools import reduce
 # from math import gcd
-from typing import Dict, Optional, Callable, Any, List
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 import torch.nn as nn
 from torchao.quantization.quant_api import (
     int4_weight_only,
     Int4WeightOnlyQuantizer,
+    int8_weight_only,
     Int8DynActInt4WeightQuantizer,
     quantize_,
 )
     find_multiple,
     get_device_str,
     get_precision,
-    set_precision,
     name_to_dtype,
+    set_precision,
     state_dict_device,
     use_et_backend,
 )

 import inspect

+
 def get_named_parameters(func: Callable) -> List[str]:
     # Get the signature of the function
     signature = inspect.signature(func)
-
+
     # Extract the parameters from the signature
     parameters = signature.parameters
-
+
     # Filter and return named parameters
     named_params = [
-        name for name, param in parameters.items()
-        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+        name
+        for name, param in parameters.items()
+        if param.kind
+        in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
     ]
     return named_params

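For reference, a quick sketch of what get_named_parameters returns (hypothetical function, not part of this change):

    def example_fn(model, groupsize: int = 128, *args, bitwidth: int = 8, **kwargs):
        ...

    get_named_parameters(example_fn)
    # -> ["model", "groupsize", "bitwidth"]: *args and **kwargs are dropped because only
    #    POSITIONAL_OR_KEYWORD and KEYWORD_ONLY parameters pass the filter above.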
-def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
+
+def validate_args(
+    named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None
+) -> Dict[str, Any]:
     for key in list(q_kwargs.keys()):  # copy the keys: entries may be deleted while iterating
         if key not in named_params:
-            print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
+            print(
+                f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring."
+            )
             del q_kwargs[key]
     return q_kwargs
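And a matching sketch for validate_args, reusing the hypothetical example_fn above with made-up option values:

    named = get_named_parameters(example_fn)        # ["model", "groupsize", "bitwidth"]
    q_kwargs = validate_args(named, {"groupsize": 32, "scheme": "symmetric"}, "example")
    # prints: Specification for quantizer example has extraneous key scheme. Ignoring.
    # q_kwargs == {"groupsize": 32}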
-
-
+
+
 #########################################################################
 ### torchchat quantization API ###

@@ -110,21 +119,30 @@ def quantize_model(
         if quantizer not in quantizer_class_dict:
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
         else:
+            ao_quant = True
             # Use tensor subclass API for int4 weight only.
             if device == "cuda" and quantizer == "linear:int4":
                 quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
+            elif quantizer == "linear:int8":
+                print("quantizer is linear int8")
+                quantize_(model, int8_weight_only())
+            else:
+                ao_quant = False
+            if ao_quant:
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
                 continue
-
+
         if quantizer in ["linear:a8wxdq", "embedding:wx"]:
             # These quantizers require float32 input weights. Note that after quantization,
             # the weights will no longer be float32, but lowbit integers
             if get_precision() != torch.float32:
-                print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
+                print(
+                    f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32."
+                )
                 set_precision(torch.float32)
-
-        # We set global precision from quantize options if it is specified at cli.py:485
+
+        # We set global precision from quantize options if it is specified at cli.py:485
         # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
         precision = get_precision()

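For context, a standalone sketch of the torchao tensor-subclass path taken by the new branch above, applied to a toy model (quantize_ and int8_weight_only are the same imports used at the top of this file; the model here is made up):

    import torch.nn as nn
    from torchao.quantization.quant_api import int8_weight_only, quantize_

    toy = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 256))
    quantize_(toy, int8_weight_only())  # swaps Linear weights for int8 tensor subclasses in place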
@@ -141,14 +159,19 @@ def quantize_model(
             model = quant_handler.quantize(model)


-
 #########################################################################
 ### QuantHandler API definition ###
 ### (unify with torchao in future) ###


 class QuantHandler:
-    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+    ):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
@@ -176,7 +199,15 @@ def quantize(self, model: nn.Module) -> nn.Module:


 class PrecisionHandler(QuantHandler):
-    def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+        *,
+        dtype,
+    ):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
@@ -205,7 +236,15 @@ def quantized_model(self) -> nn.Module:


 class ExecutorHandler(QuantHandler):
-    def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, accelerator):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+        *,
+        accelerator,
+    ):
         self.model_ = model

         if isinstance(accelerator, str):
@@ -529,147 +568,6 @@ def linear_int8_et(input, weight, scales):
     )


-class WeightOnlyInt8Linear(nn.Module):
-    __constants__ = ["in_features", "out_features"]
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-    scales: torch.Tensor
-
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        bias=None,
-        device=None,
-        dtype=None,
-        *,
-        weight: Optional[torch.Tensor] = None,
-        scales: Optional[torch.Tensor] = None,
-        groupsize: Optional[int] = None,
-    ):
-        super().__init__()
-        if dtype is None:
-            dtype = torch.get_default_dtype()
-
-        if device is None:
-            device = "cpu"
-
-        assert not bias, "Bias is not supported by LinearInt8"
-        self.in_features = in_features
-        self.out_features = out_features
-
-        assert (weight is None) == bool(
-            scales is None
-        ), "must specify both weights and scales, or neither"
-        if weight is None:
-            weight = torch.empty(
-                (out_features, in_features),
-                dtype=torch.int8,
-                device=device,
-            )
-            if groupsize is None or (groupsize == 0):
-                scales = torch.empty(out_features, dtype=dtype, device=device)
-            else:
-                n_groups = (in_features + groupsize - 1) // groupsize
-                scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)
-
-        self.register_buffer("weight", weight.to(device))
-        self.register_buffer("scales", scales.to(device))
-
-        if use_et_backend():
-            self.forward = self.et_forward
-        else:
-            self.forward = self.aoti_forward
-
-    def aoti_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_aoti(input, self.weight, self.scales)
-
-    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_et(input, self.weight, self.scales)
-
-
-class WeightOnlyInt8QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: Optional[nn.Module] = None,
-        device=None,
-        precision=None,
-        tokenizer=None,
-        *,
-        node_type: str = "*",
-        bitwidth: Optional[int] = None,
-        groupsize: Optional[int] = None,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.node_type = node_type
-        if bitwidth is None:
-            self.bitwidth = 8
-        else:
-            self.bitwidth = bitwidth
-
-    @torch.no_grad()
-    def quantize(self, module):
-        # cur_state_dict = state_dict_device(self.model_.state_dict())
-        # dict_device = "cpu"  # self.device
-
-        if self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
-
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, nn.Linear):
-                if (
-                    (self.node_type == "*")
-                    or (self.node_type == "output" and name == "output")
-                    or (self.node_type == "!output" and name != "output")
-                ):
-                    # print(f"{name, child}")
-                    input_weight = child.weight.float()
-                    # print(f"{name, child}")
-                    # print(f"in_features: {child.in_features}")
-                    # print(f"out_features: {child.out_features}")
-
-                    # print(f"expanded weight shape {input_weight.shape}")
-                    weight, scales, _ = dynamically_quantize_per_channel(
-                        input_weight,
-                        range_min,
-                        range_max,
-                        torch.int8,
-                        self.groupsize,
-                        scales_dtype=child.weight.dtype,
-                    )
-
-                    setattr(
-                        module,
-                        name,
-                        WeightOnlyInt8Linear(
-                            in_features=child.in_features,
-                            out_features=child.out_features,
-                            device=self.device,
-                            # update variables from quantization
-                            weight=weight,
-                            scales=scales,
-                            groupsize=self.groupsize,
-                        ),
-                    )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
 #########################################################################
 ##### embedding table quantization ######
 ### (unify with torchao in future) ###
@@ -886,10 +784,10 @@ def quantized_model(self) -> nn.Module:
 # class references
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
-    "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
     "linear:int4": Int4WeightOnlyQuantizer,
+    "linear:int8": int8_weight_only,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }

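A rough dispatch sketch (hypothetical call, names assumed): entries that are QuantHandler-style classes are instantiated and asked to quantize the model, whereas "linear:int8" (and "linear:int4" on CUDA) is short-circuited earlier through quantize_, so its dict entry mainly keeps the name recognized by the lookup in quantize_model.

    handler_cls = quantizer_class_dict["embedding"]   # EmbeddingOnlyQuantHandler
    handler = handler_cls(model, device=device, precision=precision, tokenizer=tokenizer, **q_kwargs)
    model = handler.quantize(model)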
@@ -932,11 +830,16 @@ def quantized_model(self) -> nn.Module:
     print("Slow fallback kernels will be used.")

 except Exception as e:
+
     class ErrorHandler(QuantHandler):
-        def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
+        def __init__(
+            self, model: Optional[nn.Module] = None, device="cpu", precision=None
+        ):
             global torchao_experimental_load_error
-            raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")
-
+            raise Exception(
+                f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}"
+            )
+
     torchao_experimental_load_error = e
     quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
     quantizer_class_dict["embedding:wx"] = ErrorHandler
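The except branch above defers a failed import of the torchao experimental kernels: nothing is raised at import time, but if a user later selects "linear:a8wxdq" or "embedding:wx", constructing the registered handler surfaces the captured error. A minimal sketch of that behavior (hypothetical lookup):

    handler_cls = quantizer_class_dict["linear:a8wxdq"]   # ErrorHandler when the import failed
    handler_cls(model)
    # -> Exception: Note: Failed to load torchao experimental quantizer with error: <original import error>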