113 changes: 65 additions & 48 deletions torchchat/utils/quantize.py
@@ -26,7 +26,7 @@

# from functools import reduce
# from math import gcd
from typing import Dict, Optional
from typing import Dict, Optional, Callable, Any, List

import torch
import torch.nn as nn
@@ -54,6 +54,34 @@
# Flag for whether the a8wxdq quantizer is available.
a8wxdq_load_error: Optional[Exception] = None

#########################################################################
### handle arg validation ###

import inspect

def get_named_parameters(func: Callable) -> List[str]:
    # Get the signature of the function
    signature = inspect.signature(func)

    # Extract the parameters from the signature
    parameters = signature.parameters

    # Filter and return named parameters
    named_params = [
        name for name, param in parameters.items()
        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    ]
    return named_params

def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
    for key in list(q_kwargs.keys()):
        if key not in named_params:
            print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
            del q_kwargs[key]
    return q_kwargs
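# Illustrative sketch (not part of this change): how the two helpers above are
# meant to be used together. The handler class and kwargs below are made up.
#
#   class _FakeHandler:
#       def __init__(self, model=None, device="cpu", precision=None, groupsize: int = 128):
#           ...
#
#   named_params = get_named_parameters(_FakeHandler.__init__)
#   # -> ["self", "model", "device", "precision", "groupsize"]
#   q_kwargs = validate_args(named_params, {"groupsize": 256, "group_size": 256}, "linear:fake")
#   # warns that "group_size" is extraneous and drops it, returning {"groupsize": 256}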


#########################################################################
### torchchat quantization API ###

@@ -79,56 +107,32 @@ def quantize_model(
quantize_options = json.loads(quantize_options)

for quantizer, q_kwargs in quantize_options.items():
# Test if a8wxdq quantizer is available; Surface error if not.
if quantizer == "linear:a8wxdq" and a8wxdq_load_error is not None:
raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")

if (
quantizer not in quantizer_class_dict
and quantizer not in ao_quantizer_class_dict
):
if quantizer not in quantizer_class_dict:
raise RuntimeError(f"unknown quantizer {quantizer} specified")
if quantizer in ao_quantizer_class_dict:
else:
# Use tensor subclass API for int4 weight only.
if device == "cuda" and quantizer == "linear:int4":
quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
if not support_tensor_subclass:
unwrap_tensor_subclass(model)
continue

# We set global precision from quantize options if it is specified at cli.py:485
# so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
precision = get_precision()

try:
if quantizer == "linear:a8wxdq":
quant_handler = ao_quantizer_class_dict[quantizer](
device=device,
precision=precision,
bitwidth=q_kwargs.get("bitwidth", 4),
groupsize=q_kwargs.get("groupsize", 128),
has_weight_zeros=q_kwargs.get("has_weight_zeros", False),
)
else:
# Easier to ask forgiveness than permission
quant_handler = ao_quantizer_class_dict[quantizer](
groupsize=q_kwargs["groupsize"], device=device, precision=precision
)
except TypeError as e:
if "unexpected keyword argument 'device'" in str(e):
quant_handler = ao_quantizer_class_dict[quantizer](
groupsize=q_kwargs["groupsize"], precision=precision
)
elif "unexpected keyword argument 'precision'" in str(e):
quant_handler = ao_quantizer_class_dict[quantizer](
groupsize=q_kwargs["groupsize"], device=device
)
else:
raise e
q = quantizer_class_dict[quantizer]
named_params = get_named_parameters(q.__init__)
q_kwargs = validate_args(named_params, q_kwargs, quantizer)

# Handle tokenizer for scenarios where the quantizer needs to tokenize sample inputs
if "tokenizer" in named_params:
    q_kwargs["tokenizer"] = tokenizer
quant_handler = q(device=device, precision=precision, **q_kwargs)

# quantize model
model = quant_handler.quantize(model)
else:
model = quantizer_class_dict[quantizer](
model, device=device, tokenizer=tokenizer, **q_kwargs
).quantized_model()
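# Illustrative sketch (not part of this change): a quantize_options spec as it
# might reach the dispatch loop above. The exact option values are assumptions;
# each top-level key must name an entry in quantizer_class_dict, and any
# per-quantizer kwarg that the handler's __init__ does not accept is reported
# and dropped by validate_args.
#
#   quantize_options = '{"linear:int8": {"groupsize": 256}, "precision": {"dtype": "bf16"}}'
#   # After json.loads, the loop builds WeightOnlyInt8QuantHandler(device=..., precision=..., groupsize=256)
#   # and PrecisionHandler(device=..., precision=..., dtype="bf16"), then calls .quantize(model) on each.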



#########################################################################
@@ -137,7 +141,7 @@ def quantize_model(


class QuantHandler:
def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
self.model_ = model
self.device = device
self.tokenizer = tokenizer
@@ -154,13 +158,18 @@ def quantized_model(self) -> nn.Module:
self.model_.load_state_dict(model_updated_state_dict)
return self.model_

# Fallback for torchchat QuantHandlers that do not implement .quantize()
def quantize(self, model: nn.Module) -> nn.Module:
self.model_ = model
return self.quantized_model()


#########################################################################
### wrapper for setting precision as a QuantHandler ###


class PrecisionHandler(QuantHandler):
def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype):
self.model_ = model
self.device = device
self.tokenizer = tokenizer
Expand All @@ -169,6 +178,9 @@ def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
dtype = name_to_dtype(dtype, device)
self.dtype = dtype

# We simply ignore precision, because dtype carries the precision argument (possibly as a string)
# maybe: assert(precision in [self.dtype, None])

def create_quantized_state_dict(self) -> Dict: # "StateDict"
pass

@@ -186,7 +198,7 @@ def quantized_model(self) -> nn.Module:


class ExecutorHandler(QuantHandler):
def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, accelerator):
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, accelerator):
self.model_ = model

if isinstance(accelerator, str):
@@ -573,8 +585,9 @@ def et_forward(self, input: torch.Tensor) -> torch.Tensor:
class WeightOnlyInt8QuantHandler(QuantHandler):
def __init__(
self,
model: nn.Module,
model: Optional[nn.Module] = None,
device,
precision=None,
tokenizer=None,
*,
node_type: str = "*",
@@ -774,8 +787,9 @@ def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
class EmbeddingOnlyQuantHandler(QuantHandler):
def __init__(
self,
model: nn.Module,
model: Optional[nn.Module] = None,
device,
precision=None,
tokenizer=None,
*,
bitwidth: int = 8,
@@ -868,9 +882,6 @@ def quantized_model(self) -> nn.Module:
"linear:int8": WeightOnlyInt8QuantHandler,
"precision": PrecisionHandler,
"executor": ExecutorHandler,
}

ao_quantizer_class_dict = {
"linear:int4": Int4WeightOnlyQuantizer,
"linear:a8w4dq": Int8DynActInt4WeightQuantizer,
}
@@ -890,7 +901,7 @@ def quantized_model(self) -> nn.Module:
sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api)
from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
Contributor:

Any reason we can't put this inside of the quantizer_class_dict?

Contributor:

It seems cleaner to try import Int8DynActIntxWeightQuantizer first and fallback to ErrorHandler, then assign this handler to linear:a8wxdq in quantizer_class_dict.

Contributor (Author):

> It seems cleaner to try import Int8DynActIntxWeightQuantizer first and fallback to ErrorHandler, then assign this handler to linear:a8wxdq in quantizer_class_dict.

This is what the code is doing now: try to import, and if the import fails, set up the error handler.

Or were you thinking of a wrapper that imports the class Int8DynActIntxWeightQuantizer, raises the error if the import fails, and dispatches to the imported class if it succeeds? I assumed that all the conditional import logic will go away soonish, since we should know what version of AO we pin and whether it has the new class. (And once that enablement is there, it doesn't disappear again, so we can just do a simple import in the future.)

Because __init__ can't return an alternate class, the wrapper would have to redispatch all methods internally, which isn't a big deal per se, but may add readability concerns? I'm happy to go either way, ideally as a follow-on.

LMK what you think the best long-term trajectory is for this functionality, and we'll align the code to that. (I think the current version is preferable to the previous state, because we don't have to special-case in the dispatch loop.)
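For illustration only (hypothetical, not part of this PR): a minimal sketch of the wrapper idea discussed above, assuming torchao_experimental_quant_api is importable once registered in sys.modules; it defers the import to construction time and redispatches to the real quantizer.

class LazyA8wxdqQuantizer:
    # Hypothetical wrapper: import lazily, surface the load error only when the quantizer is used.
    def __init__(self, *args, **kwargs):
        try:
            from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
        except Exception as e:
            raise Exception(f"Failed to load torchao experimental a8wxdq quantizer: {e}")
        self._impl = Int8DynActIntxWeightQuantizer(*args, **kwargs)

    def quantize(self, model):
        # Redispatch to the imported quantizer.
        return self._impl.quantize(model)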

Contributor:

I think if this can go away soon, I don't have a strong opinion and can live with this.


# Try loading custom op
try:
@@ -903,4 +914,10 @@ def quantized_model(self) -> nn.Module:
print("Slow fallback kernels will be used.")

except Exception as e:
    class ErrorHandler(QuantHandler):
        def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None):
            global a8wxdq_load_error
            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")

    a8wxdq_load_error = e
    quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
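# Illustrative note (not part of this change): with this fallback registered,
# requesting the quantizer when the import failed raises at handler construction
# inside the dispatch loop, e.g. (hypothetical invocation):
#
#   quantize_model(model, device="cpu", quantize_options='{"linear:a8wxdq": {}}')
#   # -> Exception: Note: Failed to load torchao experimental a8wxdq quantizer with error: ...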