This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 8b66b1b

Align AO and CHAT quantizers in quantize.py
Implement the AO API in torchchat quantization handlers and unify the logic:
1. Implement .quantize() for torchchat quantization handlers, with arguments consistent with the AO quantizers.
2. Remove the special handling for the various parameter combinations; instead, run validate_args before calling each handler with **q_kwargs.
3. Remove the check probing whether a8wx loaded successfully; if loading fails, install an error-reporting handler that is invoked as the quant handler and raises the load error.
4. Unify the torchchat and AO quantization handler dicts behind shared calling logic.
1 parent 397967f commit 8b66b1b
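For context, quantize_model() parses its quantize_options argument as JSON mapping each quantizer name to its keyword arguments, and after this change every entry is dispatched through the same path. A minimal sketch of the spec shape (hypothetical example; the option values are illustrative, not canonical defaults):

    # Hypothetical quantize spec: each key must exist in quantizer_class_dict,
    # and each value becomes the q_kwargs passed to that handler's __init__.
    quantize_options = '{"linear:int8": {"groupsize": 256}, "precision": {"dtype": "bf16"}}'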

File tree: 1 file changed (+64, -48 lines)


torchchat/utils/quantize.py

Lines changed: 64 additions & 48 deletions
@@ -26,7 +26,7 @@
 # from functools import reduce
 # from math import gcd
-from typing import Dict, Optional
+from typing import Dict, Optional, Callable, Any, List

 import torch
 import torch.nn as nn
@@ -54,6 +54,34 @@
 # Flag for whether the a8wxdq quantizer is available.
 a8wxdq_load_error: Optional[Exception] = None

+#########################################################################
+###                   handle arg validation                           ###
+
+import inspect
+
+def get_named_parameters(func: Callable) -> List[str]:
+    # Get the signature of the function
+    signature = inspect.signature(func)
+
+    # Extract the parameters from the signature
+    parameters = signature.parameters
+
+    # Filter and return named parameters
+    named_params = [
+        name for name, param in parameters.items()
+        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+    ]
+    return named_params
+
+def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict:
+    # Iterate over a copy of the keys so entries can be deleted while iterating
+    for key in list(q_kwargs.keys()):
+        if key not in named_params:
+            print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
+            del q_kwargs[key]
+    return q_kwargs
+
 #########################################################################
 ###                  torchchat quantization API                       ###
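A standalone sketch (not part of the commit) of what the two helpers above do, assuming get_named_parameters and validate_args are in scope; ToyHandler is a hypothetical stand-in for a real quantizer class:

    class ToyHandler:
        # Hypothetical handler: the three common kwargs plus one keyword-only option.
        def __init__(self, model=None, device="cpu", precision=None, *, groupsize: int = 128):
            self.groupsize = groupsize

    named = get_named_parameters(ToyHandler.__init__)
    # named == ['self', 'model', 'device', 'precision', 'groupsize']

    q_kwargs = validate_args(named, {"groupsize": 256, "bitwidth": 4}, "toy")
    # prints "Specification for quantizer toy has extraneous key bitwidth. Ignoring."
    # q_kwargs == {'groupsize': 256}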

@@ -79,56 +107,32 @@ def quantize_model(
     quantize_options = json.loads(quantize_options)

     for quantizer, q_kwargs in quantize_options.items():
-        # Test if a8wxdq quantizer is available; Surface error if not.
-        if quantizer == "linear:a8wxdq" and a8wxdq_load_error is not None:
-            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
-
-        if (
-            quantizer not in quantizer_class_dict
-            and quantizer not in ao_quantizer_class_dict
-        ):
+        if quantizer not in quantizer_class_dict:
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
-        if quantizer in ao_quantizer_class_dict:
+        else:
             # Use tensor subclass API for int4 weight only.
             if device == "cuda" and quantizer == "linear:int4":
                 quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
                 continue
+
             # We set global precision from quantize options if it is specified at cli.py:485
             # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
             precision = get_precision()

-            try:
-                if quantizer == "linear:a8wxdq":
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        device=device,
-                        precision=precision,
-                        bitwidth=q_kwargs.get("bitwidth", 4),
-                        groupsize=q_kwargs.get("groupsize", 128),
-                        has_weight_zeros=q_kwargs.get("has_weight_zeros", False),
-                    )
-                else:
-                    # Easier to ask forgiveness than permission
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        groupsize=q_kwargs["groupsize"], device=device, precision=precision
-                    )
-            except TypeError as e:
-                if "unexpected keyword argument 'device'" in str(e):
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        groupsize=q_kwargs["groupsize"], precision=precision
-                    )
-                elif "unexpected keyword argument 'precision'" in str(e):
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        groupsize=q_kwargs["groupsize"], device=device
-                    )
-                else:
-                    raise e
+            q = quantizer_class_dict[quantizer]
+            named_params = get_named_parameters(q.__init__)
+            q_kwargs = validate_args(named_params, q_kwargs, quantizer)
+
+            # Pass the tokenizer when the quantizer needs it to tokenize sample inputs
+            if "tokenizer" in named_params:
+                q_kwargs["tokenizer"] = tokenizer
+            quant_handler = q(device=device, precision=precision, **q_kwargs)
+
+            # quantize model
             model = quant_handler.quantize(model)
-        else:
-            model = quantizer_class_dict[quantizer](
-                model, device=device, tokenizer=tokenizer, **q_kwargs
-            ).quantized_model()
+


 #########################################################################
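Mirroring the new loop body, dispatching one spec entry now looks the same for every handler. A hedged sketch (names as in the diff above; the quantizer key and values are illustrative):

    q = quantizer_class_dict["linear:int8"]           # look up the handler class
    named_params = get_named_parameters(q.__init__)   # what its __init__ accepts
    q_kwargs = validate_args(named_params, {"groupsize": 256}, "linear:int8")
    quant_handler = q(device="cpu", precision=torch.float16, **q_kwargs)
    model = quant_handler.quantize(model)             # uniform AO-style entry point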
@@ -137,7 +141,7 @@ def quantize_model(


 class QuantHandler:
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
@@ -154,13 +158,17 @@ def quantized_model(self) -> nn.Module:
         self.model_.load_state_dict(model_updated_state_dict)
         return self.model_

+    def quantize(self, model: nn.Module) -> nn.Module:
+        self.model_ = model
+        return self.quantized_model()
+

 #########################################################################
 ###          wrapper for setting precision as a QuantHandler          ###


 class PrecisionHandler(QuantHandler):
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None, *, dtype):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
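The quantize() shim added above lets a legacy handler that only implements quantized_model() satisfy the AO-style quantize(model) entry point. A toy illustration (LegacyHandler is hypothetical):

    class LegacyHandler(QuantHandler):
        def quantized_model(self) -> nn.Module:
            # Legacy path: read the model from self.model_ and return it transformed.
            return self.model_

    handler = LegacyHandler(device="cpu")
    model = handler.quantize(model)  # shim stores model on self, then delegates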
@@ -169,6 +177,9 @@ def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
         dtype = name_to_dtype(dtype, device)
         self.dtype = dtype

+        # We simply ignore precision, because dtype carries the precision (possibly as a string).
+        # maybe: assert(precision in [self.dtype, None])
+
     def create_quantized_state_dict(self) -> Dict:  # "StateDict"
         pass
@@ -186,7 +197,7 @@ def quantized_model(self) -> nn.Module:


 class ExecutorHandler(QuantHandler):
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, accelerator):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None, *, accelerator):
         self.model_ = model

         if isinstance(accelerator, str):
@@ -573,8 +584,9 @@ def et_forward(self, input: torch.Tensor) -> torch.Tensor:
 class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
-        model: nn.Module,
-        device,
+        model: Optional[nn.Module] = None,
+        device=None,
+        precision=None,
         tokenizer=None,
         *,
         node_type: str = "*",
@@ -774,8 +786,9 @@ def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
 class EmbeddingOnlyQuantHandler(QuantHandler):
     def __init__(
         self,
-        model: nn.Module,
-        device,
+        model: Optional[nn.Module] = None,
+        device=None,
+        precision=None,
         tokenizer=None,
         *,
         bitwidth: int = 8,
@@ -868,9 +881,6 @@ def quantized_model(self) -> nn.Module:
     "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
-}
-
-ao_quantizer_class_dict = {
     "linear:int4": Int4WeightOnlyQuantizer,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
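With the merge above, torchchat-native handlers (e.g. "linear:int8") and torchao quantizers (e.g. "linear:int4") resolve through one registry, so the quantize_model() loop needs no per-family branching.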
@@ -890,7 +900,7 @@ def quantized_model(self) -> nn.Module:
     sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
     torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api)
     from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
-    ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
+    quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer

     # Try loading custom op
     try:
@@ -903,4 +913,10 @@ def quantized_model(self) -> nn.Module:
         print("Slow fallback kernels will be used.")

 except Exception as e:
+    class ErrorHandler(QuantHandler):
+        def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None):
+            global a8wxdq_load_error
+            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
+
     a8wxdq_load_error = e
+    quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
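The net effect: "linear:a8wxdq" is always registered, and a failed import only surfaces at the moment that quantizer is actually selected. A sketch of the resulting behavior:

    handler_cls = quantizer_class_dict["linear:a8wxdq"]  # lookup always succeeds
    handler_cls(device="cpu")  # raises the stored a8wxdq_load_error message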
