This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 3d63968: bump torchao pin
1 parent: 7d5ba09

4 files changed: +44, -25 lines


.github/workflows/pull.yml

Lines changed: 4 additions & 4 deletions
@@ -1132,14 +1132,14 @@ jobs:
 wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
 export PRMT="Once upon a time in a land far away"
 echo "Generate eager"
-python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
 echo "Generate compile"
-python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile
+python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile
 echo "Export and run ET (C++ runner)"
-python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
 ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"
 echo "Export and run AOTI (C++ runner)"
-python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
 ./cmake-out/aoti_run ./model.so -z ./tokenizer.model -t 0 -i "${PRMT}"
 echo "Generate AOTI"
 python torchchat.py generate stories110M --dso-path ./model.so --prompt "${PRMT}"
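
For reference, the new `--quantize` value used throughout this workflow is a single JSON object that stacks the two experimental schemes. A minimal sketch of producing that string with the standard-library `json` module, with the values copied from the commands above:

```
import json

# Mirror of the config in the updated workflow: 2-bit groupwise embedding
# quantization plus 3-bit dynamically-activation-quantized linear layers.
quantize_config = {
    "embedding:wx": {"bitwidth": 2, "groupsize": 32},
    "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": False},
}

# torchchat takes this as a JSON string on the command line, e.g.
#   --quantize '<this string>'
print(json.dumps(quantize_config))
```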

docs/quantization.md

Lines changed: 15 additions & 8 deletions
@@ -121,22 +121,29 @@ python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my n
 ## Experimental TorchAO lowbit kernels
 
 ### Use
-The quantization scheme a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
+
+#### linear:a8wxdq
+The quantization scheme linear:a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
 It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7), groupsize, and has_weight_zeros (true, false).
 The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
 Roughly speaking, {bitwidth: 4, groupsize: 32, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.
 
-You should expect high performance on ARM CPU if bitwidth is 1, 2, 3, 4, or 5 and groupsize is divisible by 16. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
+You should expect high performance on ARM CPU if bitwidth is 1, 2, 3, 4, 5, or 6 and groupsize is divisible by 16. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
+
+#### embedding:wx
+The quantization scheme embedding:wx quantizes embeddings in a groupwise manner with the specified bitwidth and groupsize. It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7) and groupsize. Unlike linear:a8wxdq, embedding:wx always quantizes with scales and zeros.
+
+You should expect high performance on ARM CPU if bitwidth is 1, 2, 3, 4, 5, or 6 and groupsize is divisible by 32. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
 
 ### Setup
-To use a8wxdq, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+To use linear:a8wxdq and embedding:wx, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
 
 From the torchchat root directory, run
 ```
 sh torchchat/utils/scripts/build_torchao_ops.sh
 ```
 
-This should take about 10 seconds to complete. Once finished, you can use a8wxdq in torchchat.
+This should take about 10 seconds to complete.
 
 Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts that build the runners.
 
@@ -156,17 +163,17 @@ Below we show how to use the new kernels. Except for ExecuTorch, you can specif
 
 #### Eager mode
 ```
-OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
 ```
 
 #### torch.compile
 ```
-OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
 ```
 
 #### AOTI
 ```
-OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-dso llama3_1.so
+OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-dso llama3_1.so
 OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so --prompt "Once upon a time," --num-samples 5
 ```
 
@@ -178,7 +185,7 @@ OMP_NUM_THREADS=6 ./cmake-out/aoti_run llama3_1.so -z $HOME/.torchchat/model-cac
 
 #### ExecuTorch
 ```
-python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-pte llama3_1.pte
+python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-pte llama3_1.pte
 ```
 
 Note: only the torchchat ExecuTorch C++ runner, built using the setup instructions above, can run the exported *.pte file; it will not work with the `python torchchat.py generate` command.
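
The has_weight_zeros option documented above boils down to scales-only (symmetric) versus scales-plus-zeros (asymmetric) groupwise quantization. Below is a minimal, self-contained sketch of that arithmetic for intuition only; it is not the torchao kernel implementation, and `quantize_group` is a made-up helper:

```
import torch

def quantize_group(w: torch.Tensor, bitwidth: int, has_weight_zeros: bool) -> torch.Tensor:
    """Fake-quantize one group of weights to `bitwidth` bits (illustrative only)."""
    if has_weight_zeros:
        # Asymmetric: a scale and a zero point cover the group's full [min, max] range.
        qmin, qmax = 0, 2 ** bitwidth - 1
        scale = (w.max() - w.min()) / (qmax - qmin)
        zero = qmin - torch.round(w.min() / scale)
        q = torch.clamp(torch.round(w / scale) + zero, qmin, qmax)
        return scale * (q - zero)
    else:
        # Symmetric: scales only, zero point implicitly 0.
        qmax = 2 ** (bitwidth - 1) - 1
        scale = w.abs().max() / qmax
        q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax)
        return scale * q

# Groupwise: each contiguous group of `groupsize` weights gets its own scale.
weights = torch.randn(1, 128)
groupsize = 32
dequant = torch.cat(
    [quantize_group(g, bitwidth=3, has_weight_zeros=False)
     for g in weights.split(groupsize, dim=-1)],
    dim=-1,
)
print((weights - dequant).abs().max())  # groupwise quantization error
```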

torchchat/utils/quantize.py

Lines changed: 20 additions & 8 deletions
@@ -52,7 +52,7 @@
 
 
 # Flag for whether the a8wxdq quantizer is available.
-a8wxdq_load_error: Optional[Exception] = None
+torchao_experimental_load_error: Optional[Exception] = None
 
 #########################################################################
 ### handle arg validation ###
@@ -887,24 +887,35 @@ def quantized_model(self) -> nn.Module:
 
 try:
     import importlib.util
-    import sys
     import os
+    import sys
+
     torchao_build_path = f"{os.getcwd()}/torchao-build"
 
     # Try loading quantizer
     torchao_experimental_quant_api_spec = importlib.util.spec_from_file_location(
         "torchao_experimental_quant_api",
         f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py",
     )
-    torchao_experimental_quant_api = importlib.util.module_from_spec(torchao_experimental_quant_api_spec)
+    torchao_experimental_quant_api = importlib.util.module_from_spec(
+        torchao_experimental_quant_api_spec
+    )
     sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
-    torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api)
-    from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
-    quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
+    torchao_experimental_quant_api_spec.loader.exec_module(
+        torchao_experimental_quant_api
+    )
+    from torchao_experimental_quant_api import (
+        Int8DynActIntxWeightLinearQuantizer,
+        IntxWeightEmbeddingQuantizer,
+    )
+
+    ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
+    ao_quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
 
     # Try loading custom op
     try:
         import glob
+
         libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/libtorchao_ops_aten.*")
         libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
         torch.ops.load_library(libs[0])
@@ -915,8 +926,9 @@ def quantized_model(self) -> nn.Module:
 except Exception as e:
     class ErrorHandler(QuantHandler):
         def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
-            global a8wxdq_load_error
-            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
+            global torchao_experimental_load_error
+            raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")
 
     a8wxdq_load_error = e
     quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
+    quantizer_class_dict["embedding:wx"] = ErrorHandler

torchchat/utils/scripts/install_utils.sh

Lines changed: 5 additions & 5 deletions
@@ -176,9 +176,10 @@ clone_torchao() {
     pushd ${TORCHCHAT_ROOT}/torchao-build/src
     echo $pwd
 
-    git clone https://github.com/pytorch/ao.git
-    cd ao
-    git checkout $(cat ${TORCHCHAT_ROOT}/install/.pins/torchao-pin.txt)
+    # git clone https://github.com/pytorch/ao.git
+    # cd ao
+    # git checkout $(cat ${TORCHCHAT_ROOT}/install/.pins/torchao-pin.txt)
+    cp -R $HOME/fbsource/fbcode/pytorch/ao .
 
     popd
 }
@@ -191,7 +192,6 @@ install_torchao_aten_ops() {
     cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \
         -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT_DIR} \
         -DCMAKE_BUILD_TYPE="Release" \
-        -DTORCHAO_OP_TARGET="aten" \
         -S . \
         -B ${CMAKE_OUT_DIR} -G Ninja
     cmake --build ${CMAKE_OUT_DIR} --target install --config Release
@@ -207,7 +207,7 @@ install_torchao_executorch_ops() {
     cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \
         -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT_DIR} \
         -DCMAKE_BUILD_TYPE="Release" \
-        -DTORCHAO_OP_TARGET="executorch" \
+        -DTORCHAO_BUILD_EXECUTORCH_OPS=ON \
         -DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \
         -DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \
         -S . \
