@@ -26,7 +26,7 @@
 
 # from functools import reduce
 # from math import gcd
-from typing import Dict, Optional
+from typing import Dict, Optional, Callable, Any, List
 
 import torch
 import torch.nn as nn
@@ -54,6 +54,33 @@
 # Flag for whether the a8wxdq quantizer is available.
 a8wxdq_load_error: Optional[Exception] = None
 
+#########################################################################
+### handle arg validation ###
+
+import inspect
+
+def get_named_parameters(func: Callable) -> List[str]:
+    # Get the signature of the function
+    signature = inspect.signature(func)
+
+    # Extract the parameters from the signature
+    parameters = signature.parameters
+
+    # Filter and return named parameters
+    named_params = [
+        name for name, param in parameters.items()
+        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+    ]
+    return named_params
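+
+# Illustrative example (hypothetical class, not part of torchchat): for
+#     class MyQuantizer:
+#         def __init__(self, device="cpu", precision=None, *, groupsize=128): ...
+# get_named_parameters(MyQuantizer.__init__) returns
+# ["self", "device", "precision", "groupsize"].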
+
+def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
+    # Iterate over a copy of the keys so entries can be deleted while looping.
+    for key in list(q_kwargs.keys()):
+        if key not in named_params:
+            print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
+            del q_kwargs[key]
+    return q_kwargs
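+
+# Illustrative: validate_args(["self", "device", "groupsize"], {"groupsize": 256, "typo": 1}, "linear:int8")
+# warns that "typo" is extraneous and returns {"groupsize": 256}.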
+
+
 #########################################################################
 ### torchchat quantization API ###
 
@@ -79,56 +106,32 @@ def quantize_model(
         quantize_options = json.loads(quantize_options)
 
     for quantizer, q_kwargs in quantize_options.items():
-        # Test if a8wxdq quantizer is available; Surface error if not.
-        if quantizer == "linear:a8wxdq" and a8wxdq_load_error is not None:
-            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
-
-        if (
-            quantizer not in quantizer_class_dict
-            and quantizer not in ao_quantizer_class_dict
-        ):
+        if quantizer not in quantizer_class_dict:
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
-        if quantizer in ao_quantizer_class_dict:
+        else:
             # Use tensor subclass API for int4 weight only.
             if device == "cuda" and quantizer == "linear:int4":
                 quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
                 continue
+
             # We set global precision from quantize options if it is specified at cli.py:485
             # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
             precision = get_precision()
 
-            try:
-                if quantizer == "linear:a8wxdq":
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        device=device,
-                        precision=precision,
-                        bitwidth=q_kwargs.get("bitwidth", 4),
-                        groupsize=q_kwargs.get("groupsize", 128),
-                        has_weight_zeros=q_kwargs.get("has_weight_zeros", False),
-                    )
-                else:
-                    # Easier to ask forgiveness than permission
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        groupsize=q_kwargs["groupsize"], device=device, precision=precision
-                    )
-            except TypeError as e:
-                if "unexpected keyword argument 'device'" in str(e):
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        groupsize=q_kwargs["groupsize"], precision=precision
-                    )
-                elif "unexpected keyword argument 'precision'" in str(e):
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        groupsize=q_kwargs["groupsize"], device=device
-                    )
-                else:
-                    raise e
+            q = quantizer_class_dict[quantizer]
+            named_params = get_named_parameters(q.__init__)
+            q_kwargs = validate_args(named_params, q_kwargs, quantizer)
+
+            # Handle tokenizer for scenarios where the quantizer needs to tokenize sample inputs
+            if "tokenizer" in named_params:
+                q_kwargs["tokenizer"] = tokenizer
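+            # Note: all registered quantizer classes are expected to accept the
+            # unified (device, precision, **q_kwargs) constructor signature;
+            # extraneous keys were already stripped by validate_args above.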
+            quant_handler = q(device=device, precision=precision, **q_kwargs)
+
+            # quantize model
             model = quant_handler.quantize(model)
-        else:
-            model = quantizer_class_dict[quantizer](
-                model, device=device, tokenizer=tokenizer, **q_kwargs
-            ).quantized_model()
+
 
 
 #########################################################################
@@ -137,7 +140,7 @@ def quantize_model(
 
 
 class QuantHandler:
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
@@ -154,13 +157,18 @@ def quantized_model(self) -> nn.Module:
         self.model_.load_state_dict(model_updated_state_dict)
         return self.model_
 
+    # Fallback for torchchat QuantHandlers that do not implement the method .quantize()
+    def quantize(self, model: nn.Module) -> nn.Module:
+        self.model_ = model
+        return self.quantized_model()
+
 
 #########################################################################
 ### wrapper for setting precision as a QuantHandler ###
 
 
 class PrecisionHandler(QuantHandler):
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None, *, dtype):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
@@ -169,6 +177,9 @@ def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
             dtype = name_to_dtype(dtype, device)
         self.dtype = dtype
 
+    # We simply ignore precision, because dtype carries the precision arg (possibly as a string)
+    # maybe: assert precision in [self.dtype, None]
+
     def create_quantized_state_dict(self) -> Dict:  # "StateDict"
         pass
 
@@ -186,7 +197,7 @@ def quantized_model(self) -> nn.Module:
 
 
 class ExecutorHandler(QuantHandler):
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, accelerator):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None, *, accelerator):
         self.model_ = model
 
         if isinstance(accelerator, str):
@@ -573,8 +584,9 @@ def et_forward(self, input: torch.Tensor) -> torch.Tensor:
 class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
-        model: nn.Module,
-        device,
+        model: Optional[nn.Module] = None,
+        device=None,
+        precision=None,
         tokenizer=None,
         *,
         node_type: str = "*",
@@ -774,8 +786,9 @@ def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
 class EmbeddingOnlyQuantHandler(QuantHandler):
     def __init__(
         self,
-        model: nn.Module,
-        device,
+        model: Optional[nn.Module] = None,
+        device=None,
+        precision=None,
         tokenizer=None,
         *,
         bitwidth: int = 8,
@@ -868,9 +881,6 @@ def quantized_model(self) -> nn.Module:
868881 "linear:int8" : WeightOnlyInt8QuantHandler ,
869882 "precision" : PrecisionHandler ,
870883 "executor" : ExecutorHandler ,
871- }
872-
873- ao_quantizer_class_dict = {
874884 "linear:int4" : Int4WeightOnlyQuantizer ,
875885 "linear:a8w4dq" : Int8DynActInt4WeightQuantizer ,
876886}
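+
+# Illustrative quantize config (a JSON string or dict handed to quantize_model
+# above); each key selects an entry of quantizer_class_dict and its value
+# becomes that quantizer's q_kwargs:
+#   {"linear:int8": {"groupsize": 256}, "precision": {"dtype": "bf16"}}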
@@ -890,7 +900,7 @@ def quantized_model(self) -> nn.Module:
     sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
     torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api)
     from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
-    ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
+    quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
 
     # Try loading custom op
     try:
@@ -903,4 +913,10 @@ def quantized_model(self) -> nn.Module:
         print("Slow fallback kernels will be used.")
 
 except Exception as e:
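+    # Defer the import failure: register a stub handler so the error only
+    # surfaces if the user actually selects the "linear:a8wxdq" quantizer.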
+    class ErrorHandler(QuantHandler):
+        def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None):
+            global a8wxdq_load_error
+            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
+
     a8wxdq_load_error = e
+    quantizer_class_dict["linear:a8wxdq"] = ErrorHandler