diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py
index 12551a69e..8a708d416 100644
--- a/torchchat/utils/quantize.py
+++ b/torchchat/utils/quantize.py
@@ -26,7 +26,7 @@
 
 # from functools import reduce
 # from math import gcd
-from typing import Dict, Optional
+from typing import Dict, Optional, Callable, Any, List
 
 import torch
 import torch.nn as nn
@@ -54,6 +54,33 @@
 # Flag for whether the a8wxdq quantizer is available.
 a8wxdq_load_error: Optional[Exception] = None
 
+#########################################################################
+###                       handle arg validation                       ###
+
+import inspect
+
+def get_named_parameters(func: Callable) -> List[str]:
+    # Get the signature of the function
+    signature = inspect.signature(func)
+
+    # Extract the parameters from the signature
+    parameters = signature.parameters
+
+    # Filter and return named parameters
+    named_params = [
+        name for name, param in parameters.items()
+        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+    ]
+    return named_params
+
+def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
+    for key in list(q_kwargs.keys()):  # snapshot keys; deleting while iterating a dict raises RuntimeError
+        if key not in named_params:
+            print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
+            del q_kwargs[key]
+    return q_kwargs
+
+
 #########################################################################
 ###                    torchchat quantization API                     ###
 
@@ -79,56 +106,32 @@ def quantize_model(
         quantize_options = json.loads(quantize_options)
 
     for quantizer, q_kwargs in quantize_options.items():
-        # Test if a8wxdq quantizer is available; Surface error if not.
-        if quantizer == "linear:a8wxdq" and a8wxdq_load_error is not None:
-            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
-
-        if (
-            quantizer not in quantizer_class_dict
-            and quantizer not in ao_quantizer_class_dict
-        ):
+        if quantizer not in quantizer_class_dict:
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
-        if quantizer in ao_quantizer_class_dict:
+        else:
             # Use tensor subclass API for int4 weight only.
if device == "cuda" and quantizer == "linear:int4": quantize_(model, int4_weight_only(q_kwargs["groupsize"])) if not support_tensor_subclass: unwrap_tensor_subclass(model) continue + # We set global precision from quantize options if it is specified at cli.py:485 # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat precision = get_precision() - try: - if quantizer == "linear:a8wxdq": - quant_handler = ao_quantizer_class_dict[quantizer]( - device=device, - precision=precision, - bitwidth=q_kwargs.get("bitwidth", 4), - groupsize=q_kwargs.get("groupsize", 128), - has_weight_zeros=q_kwargs.get("has_weight_zeros", False), - ) - else: - # Easier to ask forgiveness than permission - quant_handler = ao_quantizer_class_dict[quantizer]( - groupsize=q_kwargs["groupsize"], device=device, precision=precision - ) - except TypeError as e: - if "unexpected keyword argument 'device'" in str(e): - quant_handler = ao_quantizer_class_dict[quantizer]( - groupsize=q_kwargs["groupsize"], precision=precision - ) - elif "unexpected keyword argument 'precision'" in str(e): - quant_handler = ao_quantizer_class_dict[quantizer]( - groupsize=q_kwargs["groupsize"], device=device - ) - else: - raise e + q = quantizer_class_dict[quantizer] + named_params = get_named_parameters(q.__init__) + q_kwargs = validate_args(named_params, q_kwargs, quantizer) + + # Handle tokenizer for scenarios where the quantizer needs to tokenizer sample inputs + if "tokenizer" in named_params: + q_kwargs["tokenizer"] = tokenizer + quant_handler = q(device=device, precision=precision, **q_kwargs) + + # quantize model model = quant_handler.quantize(model) - else: - model = quantizer_class_dict[quantizer]( - model, device=device, tokenizer=tokenizer, **q_kwargs - ).quantized_model() + ######################################################################### @@ -137,7 +140,7 @@ def quantize_model( class QuantHandler: - def __init__(self, model: nn.Module, device="cpu", tokenizer=None): + def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None): self.model_ = model self.device = device self.tokenizer = tokenizer @@ -154,13 +157,18 @@ def quantized_model(self) -> nn.Module: self.model_.load_state_dict(model_updated_state_dict) return self.model_ + # fallback for TC QuantHandlers that do not implement the method .quantize() + def quantize(self, model: nn.Module) -> nn.Module: + self.model_ = model + return self.quantized_model() + ######################################################################### ### wrapper for setting precision as a QuantHandler ### class PrecisionHandler(QuantHandler): - def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype): + def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype): self.model_ = model self.device = device self.tokenizer = tokenizer @@ -169,6 +177,9 @@ def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype): dtype = name_to_dtype(dtype, device) self.dtype = dtype + # We simply ignore precision. 
+        # maybe: assert precision in (self.dtype, None)
+
 
     def create_quantized_state_dict(self) -> Dict:  # "StateDict"
         pass
 
@@ -186,7 +197,7 @@ def quantized_model(self) -> nn.Module:
 
 
 class ExecutorHandler(QuantHandler):
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, accelerator):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None, *, accelerator):
         self.model_ = model
 
         if isinstance(accelerator, str):
@@ -573,8 +584,9 @@ def et_forward(self, input: torch.Tensor) -> torch.Tensor:
 class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
-        model: nn.Module,
-        device,
+        model: Optional[nn.Module] = None,
+        device=None,
+        precision=None,
         tokenizer=None,
         *,
         node_type: str = "*",
@@ -774,8 +786,9 @@ def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
 class EmbeddingOnlyQuantHandler(QuantHandler):
     def __init__(
         self,
-        model: nn.Module,
-        device,
+        model: Optional[nn.Module] = None,
+        device=None,
+        precision=None,
         tokenizer=None,
         *,
         bitwidth: int = 8,
@@ -868,9 +881,6 @@ def quantized_model(self) -> nn.Module:
     "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
-}
-
-ao_quantizer_class_dict = {
     "linear:int4": Int4WeightOnlyQuantizer,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
@@ -890,7 +900,7 @@
     sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
     torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api)
     from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
-    ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
+    quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
 
     # Try loading custom op
     try:
@@ -903,4 +913,10 @@
         print("Slow fallback kernels will be used.")
 
 except Exception as e:
+    # Deferred failure: the load error only surfaces if a8wxdq is actually selected and constructed.
+    class ErrorHandler(QuantHandler):
+        def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None):
+            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
+
     a8wxdq_load_error = e
+    quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
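
Reviewer note: the dispatch path above replaces the old try/except-TypeError probing with signature introspection, so every quantizer is constructed through one uniform call. A minimal, self-contained sketch of how `get_named_parameters` and `validate_args` behave; `ToyQuantizer` is a hypothetical stand-in for an entry in `quantizer_class_dict`, not torchchat code:

```python
# Sketch of the new dispatch path, runnable outside torchchat.
import inspect
from typing import Any, Callable, Dict, List, Optional


def get_named_parameters(func: Callable) -> List[str]:
    # Keep only parameters that can be passed by keyword.
    return [
        name
        for name, param in inspect.signature(func).parameters.items()
        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    ]


def validate_args(
    named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None
) -> Dict[str, Any]:
    # Snapshot the keys: deleting from a dict while iterating it raises RuntimeError.
    for key in list(q_kwargs.keys()):
        if key not in named_params:
            print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
            del q_kwargs[key]
    return q_kwargs


class ToyQuantizer:  # hypothetical quantizer_class_dict entry
    def __init__(self, device="cpu", precision=None, *, groupsize: int = 128):
        self.device, self.precision, self.groupsize = device, precision, groupsize


q_kwargs = {"groupsize": 32, "bogus_key": True}
named_params = get_named_parameters(ToyQuantizer.__init__)
q_kwargs = validate_args(named_params, q_kwargs, "linear:toy")  # warns and drops bogus_key
handler = ToyQuantizer(device="cpu", precision=None, **q_kwargs)
assert handler.groupsize == 32
```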
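
The base-class `quantize()` fallback is what lets the merged `quantizer_class_dict` hold both styles of handler: legacy torchchat handlers that only implement `quantized_model()` still work under the single `handler.quantize(model)` call. A toy sketch of that adapter; `DoublingHandler` is invented for illustration (real handlers build a quantized state dict instead):

```python
# Sketch of the QuantHandler.quantize() fallback added in this diff.
from typing import Optional

import torch
import torch.nn as nn


class QuantHandler:
    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
        self.model_ = model
        self.device = device
        self.tokenizer = tokenizer

    def quantized_model(self) -> nn.Module:  # legacy entry point
        raise NotImplementedError

    # fallback for handlers that do not override .quantize()
    def quantize(self, model: nn.Module) -> nn.Module:
        self.model_ = model
        return self.quantized_model()


class DoublingHandler(QuantHandler):
    # Toy "quantization": double every parameter in place.
    def quantized_model(self) -> nn.Module:
        with torch.no_grad():
            for p in self.model_.parameters():
                p.mul_(2)
        return self.model_


model = nn.Linear(4, 4)
model = DoublingHandler(device="cpu", precision=None).quantize(model)  # uniform call path
```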
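
Finally, the `ErrorHandler` registration at the end of the patch defers a failed optional import: instead of checking `a8wxdq_load_error` on every loop iteration, a placeholder class is registered whose constructor raises the captured error, so only users who actually select `linear:a8wxdq` see the failure. A standalone sketch of the pattern, with the failing import simulated:

```python
# Sketch of the deferred-error registration; the failing import is simulated.
from typing import Dict, Optional

quantizer_class_dict: Dict[str, type] = {}
a8wxdq_load_error: Optional[Exception] = None

try:
    raise ImportError("simulated: torchao experimental kernels unavailable")
except Exception as e:
    class ErrorHandler:
        def __init__(self, model=None, device="cpu", precision=None):
            # Surfaces the import error captured at module load time.
            raise Exception(
                f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}"
            )

    a8wxdq_load_error = e
    quantizer_class_dict["linear:a8wxdq"] = ErrorHandler

# Every other quantizer keeps working; only constructing the broken one raises:
# quantizer_class_dict["linear:a8wxdq"]()  -> Exception carrying the import error
```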