@@ -96,10 +96,19 @@ def quantize_model(
9696 precision = get_precision ()
9797
9898 try :
99- # Easier to ask forgiveness than permission
100- quant_handler = ao_quantizer_class_dict [quantizer ](
101- groupsize = q_kwargs ["groupsize" ], device = device , precision = precision
102- )
99+ if quantizer == "linear:a8wxdq" :
100+ quant_handler = ao_quantizer_class_dict [quantizer ](
101+ device = device ,
102+ precision = precision ,
103+ bitwidth = q_kwargs .get ("bitwidth" , 4 ),
104+ groupsize = q_kwargs .get ("groupsize" , 128 ),
105+ has_weight_zeros = q_kwargs .get ("has_weight_zeros" , False ),
106+ )
107+ else :
108+ # Easier to ask forgiveness than permission
109+ quant_handler = ao_quantizer_class_dict [quantizer ](
110+ groupsize = q_kwargs ["groupsize" ], device = device , precision = precision
111+ )
103112 except TypeError as e :
104113 if "unexpected keyword argument 'device'" in str (e ):
105114 quant_handler = ao_quantizer_class_dict [quantizer ](
@@ -861,3 +870,33 @@ def quantized_model(self) -> nn.Module:
861870 "linear:int4" : Int4WeightOnlyQuantizer ,
862871 "linear:a8w4dq" : Int8DynActInt4WeightQuantizer ,
863872}
873+
# Optional registration of the experimental torchao "linear:a8wxdq" quantizer.
#
# The quantizer module is loaded straight from a file inside a local
# "torchao-build" directory, and, if present, the compiled custom-op library
# from the same build tree is loaded into torch. Everything here is
# best-effort: any failure is reported on stdout and the rest of this module
# keeps working without the "linear:a8wxdq" entry.
try:
    import glob
    import importlib.util
    import os
    import sys

    # NOTE: resolved against the current working directory, so the quantizer
    # is only found when the process is started from the directory that
    # contains "torchao-build".
    torchao_build_path = f"{os.getcwd()}/torchao-build"

    # Try loading quantizer
    torchao_experimental_quant_api_spec = importlib.util.spec_from_file_location(
        "torchao_experimental_quant_api",
        f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py",
    )
    torchao_experimental_quant_api = importlib.util.module_from_spec(
        torchao_experimental_quant_api_spec
    )
    # Register before executing so the `from ... import` below resolves.
    sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
    try:
        torchao_experimental_quant_api_spec.loader.exec_module(
            torchao_experimental_quant_api
        )
    except BaseException:
        # Do not leave a half-initialised module behind in sys.modules.
        sys.modules.pop("torchao_experimental_quant_api", None)
        raise
    from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer

    ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer

    # Try loading custom op
    try:
        # Sorted so the chosen library is deterministic if several match.
        libs = sorted(
            lib
            for lib in glob.glob(
                f"{torchao_build_path}/cmake-out/liblowbit_op_aten.*"
            )
            if lib.endswith((".so", ".dylib"))
        )
        if not libs:
            # Explicit error instead of an opaque IndexError from libs[0].
            raise FileNotFoundError(
                "no liblowbit_op_aten.so/.dylib found in "
                f"{torchao_build_path}/cmake-out"
            )
        torch.ops.load_library(libs[0])
    except Exception as e:
        # The quantizer stays registered; only the custom op is unavailable.
        print("Failed to load torchao custom op library with error: ", e)
        print("Slow fallback kernels will be used.")

except Exception as e:
    print(f"Failed to load torchao experimental a8wxdq quantizer with error: {e}")
0 commit comments