This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 5d257d7

update experimental kernels in torchchat

1 parent 654bb03

4 files changed: +116 -30

.github/workflows/pull.yml

Lines changed: 48 additions & 7 deletions
@@ -1037,7 +1037,54 @@ jobs:
           ./runner/build_android.sh
           echo "Tests complete."
 
-  test-torchao-experimental:
+  test-torchao-experimental-python:
+    strategy:
+      matrix:
+        runner: [macos-14-xlarge]
+    runs-on: ${{matrix.runner}}
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v3
+        with:
+          submodules: true
+      - name: Setup Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.10.11
+      - name: Setup Xcode
+        if: runner.os == 'macOS'
+        uses: maxim-lobanov/setup-xcode@v1
+        with:
+          xcode-version: '15.3'
+      - name: Print machine info
+        run: |
+          uname -a
+          if [ $(uname -s) == Darwin ]; then
+            sysctl machdep.cpu.brand_string
+            sysctl machdep.cpu.core_count
+          fi
+      - name: Install torchchat
+        run: |
+          echo "Installing pip3 packages"
+          ./install/install_requirements.sh
+          pip3 list
+          python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
+      - name: Run inference
+        run: |
+          python torchchat.py download stories110M
+          wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+          export PRMT="Once upon a time in a land far away"
+          echo "Generate eager"
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
+          echo "Generate compile"
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile
+          echo "Export AOTI"
+          python torchchat.py export stories110M --output-aoti-package-path ./model.pt2 --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
+          echo "Generate AOTI"
+          python torchchat.py generate stories110M --aoti-package-path ./model.pt2 --prompt "${PRMT}"
+          echo "Tests complete."
+
+  test-torchao-experimental-cpp:
     strategy:
       matrix:
         runner: [macos-14-xlarge]
@@ -1091,18 +1138,12 @@ jobs:
           python torchchat.py download stories110M
           wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
           export PRMT="Once upon a time in a land far away"
-          echo "Generate eager"
-          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
-          echo "Generate compile"
-          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile
           echo "Export and run ET (C++ runner)"
           python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"
           echo "Export and run AOTI (C++ runner)"
           python torchchat.py export stories110M --output-aoti-package-path ./model.pt2 --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           ./cmake-out/aoti_run ./model.pt2 -z ./tokenizer.model -t 0 -i "${PRMT}"
-          echo "Generate AOTI"
-          python torchchat.py generate stories110M --aoti-package-path ./model.pt2 --prompt "${PRMT}"
           echo "Tests complete."
 
   test-torchao-experimental-mps:

docs/quantization.md

Lines changed: 10 additions & 6 deletions
@@ -120,13 +120,15 @@ python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my n
 
 ## Experimental TorchAO lowbit kernels
 
-WARNING: These kernels only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+If you are on a Mac with Apple Silicon, 1-8 bit quantization is available for embedding and linear layers, backed by CPU and MPS kernels.
+
+The CPU kernels are installed automatically by the torchchat install script and can be used out of the box. To use the MPS kernels, follow the setup instructions below.
 
 ### Use
 
 #### linear:a8wxdq
 The quantization scheme linear:a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
-It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7), groupsize, and has_weight_zeros (true, false).
+It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7, 8), groupsize (-1 for channelwise quantization), and has_weight_zeros (true, false).
 The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
 Roughly speaking, {bitwidth: 4, groupsize: 32, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.
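For example, channelwise 3-bit linear quantization can be requested by setting groupsize to -1. A sketch reusing the CLI pattern from the workflow above; the model and bitwidth are illustrative:

```
python torchchat.py generate stories110M --device cpu --dtype float32 \
  --quantize '{"linear:a8wxdq": {"bitwidth": 3, "groupsize": -1, "has_weight_zeros": false}}' \
  --prompt "Once upon a time"
```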
@@ -138,7 +140,9 @@ The quantization scheme embedding:wx quantizes embeddings in a groupwise manner
 You should expect high performance on ARM CPU if groupsize is divisible by 32. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
 
 ### Setup
-To use linear:a8wxdq and embedding:wx, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+If you are using the torchao ops from Python, they are available out of the box on a Mac with Apple Silicon, and you can skip these setup steps.
+
+If you plan to use the kernels from the AOTI/ExecuTorch C++ runners, follow the setup steps below.
 
 From the torchchat root directory, run
 ```
@@ -147,7 +151,7 @@ bash torchchat/utils/scripts/build_torchao_ops.sh
 
 This should take about 10 seconds to complete.
 
-Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts the build the runners.
+When building the AOTI and C++ runners, you must pass the flag link_torchao_ops to the scripts that build the runners.
 
 ```
 bash torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
@@ -175,8 +179,8 @@ OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype fl
 
 #### AOTI
 ```
-OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-dso llama3_1.so
-OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so --prompt "Once upon a time," --num-samples 5
+OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-aoti-package-path llama3_1.pt2
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --aoti-package-path llama3_1.pt2 --prompt "Once upon a time," --num-samples 5
 ```
 
 If you built the AOTI runner with link_torchao_ops as discussed in the setup section, you can also use the C++ runner:
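A hedged end-to-end sketch of that C++ runner flow, mirroring the aoti_run invocation in the workflow above (the tokenizer path is illustrative):

```
bash torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
./cmake-out/aoti_run llama3_1.pt2 -z ./tokenizer.model -t 0 -i "Once upon a time,"
```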

install/install_requirements.sh

Lines changed: 3 additions & 1 deletion
@@ -105,9 +105,11 @@ REQUIREMENTS_TO_INSTALL=(
 
 # For torchao need to install from github since nightly build doesn't have macos build.
 # TODO: Remove this and install nightly build, once it supports macos
+# USE_CPP=1 indicates that the torchao experimental aten kernels will be built and loaded
+# if on Mac with Apple Silicon
 (
   set -x
-  $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@2f97b0955953fa1a46594a27f0df2bc48d93e79d
+  USE_CPP=1 $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@11333ba2cb5c4e792bc4f5c0d70c12991f972008
)
 
 if [[ -x "$(command -v nvidia-smi)" ]]; then
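To reproduce this install step by hand, something like the following should work (a sketch; the pinned commit is the one from the diff, and the import line is only a basic sanity check):

```
USE_CPP=1 pip3 install git+https://github.com/pytorch/ao.git@11333ba2cb5c4e792bc4f5c0d70c12991f972008
python3 -c 'import torchao; print(torchao.__version__)'
```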

torchchat/utils/quantize.py

Lines changed: 55 additions & 16 deletions
@@ -50,6 +50,18 @@
     state_dict_device,
     use_et_backend,
 )
+from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
+    PackedLinearInt8DynamicActivationIntxWeightLayout,
+)
+from torchao.experimental.quant_api import (
+    int8_dynamic_activation_intx_weight,
+    IntxWeightEmbeddingQuantizer,
+)
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
+)
+from torchao.dtypes import PlainLayout
 
 
 # Flag for whether the a8wxdq quantizer is available.
@@ -117,7 +129,47 @@ def quantize_model(
             unwrap_tensor_subclass(model)
             continue
 
-        if quantizer in ["linear:a8wxdq", "embedding:wx"]:
+        if quantizer == "linear:a8wxdq":
+            if get_precision() != torch.float32:
+                print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
+                set_precision(torch.float32)
+
+            group_size = q_kwargs["groupsize"]
+            bit_width = q_kwargs["bitwidth"]
+            has_weight_zeros = q_kwargs["has_weight_zeros"]
+            granularity = PerRow()
+            if group_size != -1:
+                granularity = PerGroup(group_size)
+            weight_dtype = getattr(torch, f"int{bit_width}")
+
+            try:
+                quantize_(
+                    model,
+                    int8_dynamic_activation_intx_weight(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
+                    ),
+                )
+            except Exception as e:
+                print(f"Encountered error during quantization: {e}")
+                print("Trying with PlainLayout")
+                quantize_(
+                    model,
+                    int8_dynamic_activation_intx_weight(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                        layout=PlainLayout(),
+                    ),
+                )
+
+            if not support_tensor_subclass:
+                unwrap_tensor_subclass(model)
+            continue
+
+        if quantizer == "embedding:wx":
             # These quantizers require float32 input weights. Note that after quantization,
             # the weights will no longer be float32, but lowbit integers
             if get_precision() != torch.float32:
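As a standalone illustration of the new quantize_ path, here is a minimal sketch applying the same API to a toy module (the model, bitwidth, and groupsize are assumptions; quantize_ is imported from torchao.quantization, and the remaining names are the ones imported in this diff):

```
import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup, PerRow
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)

# Toy float32 model standing in for the torchchat model (an assumption).
model = nn.Sequential(nn.Linear(256, 256)).to(torch.float32)

# Mirror the diff's argument mapping: groupsize -1 means channelwise (PerRow),
# anything else selects PerGroup; bitwidth N maps to torch.intN.
bit_width, group_size, has_weight_zeros = 3, 128, False
granularity = PerRow() if group_size == -1 else PerGroup(group_size)
weight_dtype = getattr(torch, f"int{bit_width}")

quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=weight_dtype,
        granularity=granularity,
        has_weight_zeros=has_weight_zeros,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
    ),
)
```

Swapping in layout=PlainLayout() selects the slow reference path, which is what the fallback branch above does.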
@@ -889,10 +941,12 @@ def quantized_model(self) -> nn.Module:
 # class references
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
+    "embedding:wx": IntxWeightEmbeddingQuantizer,
     "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
     "linear:int4": Int4WeightOnlyQuantizer,
+    "linear:a8wxdq": None,  # uses quantize_ API
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }

@@ -916,26 +970,11 @@ def quantized_model(self) -> nn.Module:
         torchao_experimental_quant_api
     )
     from torchao_experimental_quant_api import (
-        Int8DynActIntxWeightLinearQuantizer,
-        IntxWeightEmbeddingQuantizer,
         UIntxWeightOnlyLinearQuantizer,
     )
-
-    quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
-    quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
     quantizer_class_dict["linear:afpwx"] = UIntxWeightOnlyLinearQuantizer
 
     # Try loading custom op
-    try:
-        import glob
-
-        libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/libtorchao_ops_aten.*")
-        libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
-        torch.ops.load_library(libs[0])
-        print("Loaded torchao cpu ops.")
-    except Exception as e:
-        print("Unable to load torchao cpu ops library. Slow fallback kernels will be used.")
-
     try:
         libname = "libtorchao_ops_mps_aten.dylib"
         libpath = f"{torchao_build_path}/cmake-out/lib/{libname}"