@@ -142,6 +142,11 @@ def quantize_model(
142142 )
143143 set_precision (torch .float32 )
144144
145+ if quantizer == "linear:afpwx" and device != "mps" :
146+ raise RuntimeError (
147+ "linear:afpwx quantization can only run on mps device!"
148+ )
149+
145150 # We set global precision from quantize options if it is specified at cli.py:485
146151 # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
147152 precision = get_precision ()
@@ -813,10 +818,12 @@ def quantized_model(self) -> nn.Module:
813818 from torchao_experimental_quant_api import (
814819 Int8DynActIntxWeightLinearQuantizer ,
815820 IntxWeightEmbeddingQuantizer ,
821+ UIntxWeightOnlyLinearQuantizer ,
816822 )
817823
818824 quantizer_class_dict ["linear:a8wxdq" ] = Int8DynActIntxWeightLinearQuantizer
819825 quantizer_class_dict ["embedding:wx" ] = IntxWeightEmbeddingQuantizer
826+ quantizer_class_dict ["linear:afpwx" ] = UIntxWeightOnlyLinearQuantizer
820827
821828 # Try loading custom op
822829 try :
@@ -826,20 +833,16 @@ def quantized_model(self) -> nn.Module:
826833 libs = list (filter (lambda l : (l .endswith ("so" ) or l .endswith ("dylib" )), libs ))
827834 torch .ops .load_library (libs [0 ])
828835 except Exception as e :
829- print ("Failed to torchao ops library with error: " , e )
830- print ("Slow fallback kernels will be used." )
831-
832- except Exception as e :
836+ print (
837+ "Unabled to load torchao cpu ops library. Slow fallback kernels will be used."
838+ )
833839
834- class ErrorHandler (QuantHandler ):
835- def __init__ (
836- self , model : Optional [nn .Module ] = None , device = "cpu" , precision = None
837- ):
838- global torchao_experimental_load_error
839- raise Exception (
840- f"Note: Failed to load torchao experimental quantizer with error: { torchao_experimental_load_error } "
841- )
840+ try :
841+ libname = "libtorchao_ops_mps_aten.dylib"
842+ libpath = f"{ torchao_build_path } /cmake-out/lib/{ libname } "
843+ torch .ops .load_library (libpath )
844+ except Exception as e :
845+ print ("Unabled to load torchao mps ops library." )
842846
843- torchao_experimental_load_error = e
844- quantizer_class_dict ["linear:a8wxdq" ] = ErrorHandler
845- quantizer_class_dict ["embedding:wx" ] = ErrorHandler
847+ except Exception as e :
848+ print ("Unabled to import torchao experimental quant_api with error: " , e )
0 commit comments