This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit bdac616

update experimental kernels in torchchat
1 parent 5684175 commit bdac616

4 files changed: 116 additions & 30 deletions


.github/workflows/pull.yml

Lines changed: 48 additions & 7 deletions
@@ -1055,7 +1055,54 @@ jobs:
         ./runner/build_android.sh
         echo "Tests complete."
 
-  test-torchao-experimental:
+  test-torchao-experimental-python:
+    strategy:
+      matrix:
+        runner: [macos-14-xlarge]
+    runs-on: ${{matrix.runner}}
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v3
+        with:
+          submodules: true
+      - name: Setup Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.10.11
+      - name: Setup Xcode
+        if: runner.os == 'macOS'
+        uses: maxim-lobanov/setup-xcode@v1
+        with:
+          xcode-version: '15.3'
+      - name: Print machine info
+        run: |
+          uname -a
+          if [ $(uname -s) == Darwin ]; then
+            sysctl machdep.cpu.brand_string
+            sysctl machdep.cpu.core_count
+          fi
+      - name: Install torchchat
+        run: |
+          echo "Installing pip3 packages"
+          ./install/install_requirements.sh
+          pip3 list
+          python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
+      - name: Run inference
+        run: |
+          python torchchat.py download stories110M
+          wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+          export PRMT="Once upon a time in a land far away"
+          echo "Generate eager"
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
+          echo "Generate compile"
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile
+          echo "Export AOTI"
+          python torchchat.py export stories110M --output-aoti-package-path ./model.pt2 --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
+          echo "Generate AOTI"
+          python torchchat.py generate stories110M --aoti-package-path ./model.pt2 --prompt "${PRMT}"
+          echo "Tests complete."
+
+  test-torchao-experimental-cpp:
     strategy:
       matrix:
         runner: [macos-14-xlarge]
@@ -1109,18 +1156,12 @@ jobs:
           python torchchat.py download stories110M
           wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
           export PRMT="Once upon a time in a land far away"
-          echo "Generate eager"
-          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
-          echo "Generate compile"
-          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile
           echo "Export and run ET (C++ runner)"
           python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"
           echo "Export and run AOTI (C++ runner)"
           python torchchat.py export stories110M --output-aoti-package-path ./model.pt2 --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           ./cmake-out/aoti_run ./model.pt2 -z ./tokenizer.model -t 0 -i "${PRMT}"
-          echo "Generate AOTI"
-          python torchchat.py generate stories110M --aoti-package-path ./model.pt2 --prompt "${PRMT}"
           echo "Tests complete."
 
   test-torchao-experimental-mps:

docs/quantization.md

Lines changed: 10 additions & 6 deletions
@@ -120,13 +120,15 @@ python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my n
 
 ## Experimental TorchAO lowbit kernels
 
-WARNING: These kernels only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+If you are on a Mac with Apple Silicon, we have 1- to 8-bit quantization available for embedding and linear layers, backed by CPU and MPS kernels.
+
+The CPU kernels are installed automatically by the torchchat install script and can be used out of the box. To use the MPS kernels, follow the setup instructions below.
 
 ### Use
 
 #### linear:a8wxdq
 The quantization scheme linear:a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
-It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7), groupsize, and has_weight_zeros (true, false).
+It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7, 8), groupsize (-1 if channelwise desired), and has_weight_zeros (true, false).
 The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
 Roughly speaking, {bitwidth: 4, groupsize: 32, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.
 
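The linear:a8wxdq arguments above map directly onto torchao's quantize_ API. Below is a minimal sketch, assuming the experimental imports introduced by this commit (see the quantize.py diff later in this page) are available; the signatures are experimental and illustrative rather than a stable public API:

```
import torch
from torchao.quantization.quant_api import quantize_
from torchao.quantization.granularity import PerGroup, PerRow
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)

# The same arguments the --quantize JSON exposes.
bitwidth, groupsize, has_weight_zeros = 3, 128, False

# groupsize == -1 selects channelwise (per-row) scales; otherwise groupwise.
granularity = PerRow() if groupsize == -1 else PerGroup(groupsize)
weight_dtype = getattr(torch, f"int{bitwidth}")  # torch.int3 for bitwidth 3

# Toy float32 model standing in for a transformer's linear layers.
model = torch.nn.Sequential(torch.nn.Linear(256, 256, dtype=torch.float32))

quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=weight_dtype,
        granularity=granularity,
        has_weight_zeros=has_weight_zeros,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
    ),
)
```

When the packed layout is unsupported on the host, the quantize.py change in this commit falls back to PlainLayout with the same arguments.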
@@ -138,7 +140,9 @@ The quantization scheme embedding:wx quantizes embeddings in a groupwise manner
 You should expect high performance on ARM CPU if groupsize is divisible by 32. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
 
 ### Setup
-To use linear:a8wxdq and embedding:wx, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+If you are using the torchao ops from Python, they are available out of the box on a Mac with Apple Silicon, and you can skip these setup steps.
+
+If you plan to use the kernels from the AOTI/ExecuTorch C++ runners, follow the setup steps below.
 
 From the torchchat root directory, run
 ```
@@ -147,7 +151,7 @@ bash torchchat/utils/scripts/build_torchao_ops.sh
 ```
 This should take about 10 seconds to complete.
 
-Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts the build the runners.
+When building the AOTI and C++ runners, you must pass the flag link_torchao_ops to the scripts that build the runners.
 
 ```
 bash torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
@@ -175,8 +179,8 @@ OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype fl
 
 #### AOTI
 ```
-OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-dso llama3_1.so
-OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so --prompt "Once upon a time," --num-samples 5
+OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-aoti-package-path llama3_1.pt2
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --aoti-package-path llama3_1.pt2 --prompt "Once upon a time," --num-samples 5
 ```
 
 If you built the AOTI runner with link_torchao_ops as discussed in the setup section, you can also use the C++ runner:

install/install_requirements.sh

Lines changed: 3 additions & 1 deletion
@@ -117,9 +117,11 @@ fi
 
 # For torchao need to install from github since nightly build doesn't have macos build.
 # TODO: Remove this and install nightly build, once it supports macos
+# USE_CPP=1 indicates that the torchao experimental aten kernels will be built and loaded
+# if on Mac with Apple Silicon
 (
   set -x
-  $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@2f97b0955953fa1a46594a27f0df2bc48d93e79d
+  USE_CPP=1 $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@11333ba2cb5c4e792bc4f5c0d70c12991f972008
 )
 
 if [[ -x "$(command -v nvidia-smi)" ]]; then

torchchat/utils/quantize.py

Lines changed: 55 additions & 16 deletions
@@ -50,6 +50,18 @@
     state_dict_device,
     use_et_backend,
 )
+from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
+    PackedLinearInt8DynamicActivationIntxWeightLayout,
+)
+from torchao.experimental.quant_api import (
+    int8_dynamic_activation_intx_weight,
+    IntxWeightEmbeddingQuantizer,
+)
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
+)
+from torchao.dtypes import PlainLayout
 
 
 # Flag for whether the a8wxdq quantizer is available.
@@ -117,7 +129,47 @@ def quantize_model(
             unwrap_tensor_subclass(model)
             continue
 
-        if quantizer in ["linear:a8wxdq", "embedding:wx"]:
+        if quantizer == "linear:a8wxdq":
+            if get_precision() != torch.float32:
+                print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
+                set_precision(torch.float32)
+
+            group_size = q_kwargs["groupsize"]
+            bit_width = q_kwargs["bitwidth"]
+            has_weight_zeros = q_kwargs["has_weight_zeros"]
+            granularity = PerRow()
+            if group_size != -1:
+                granularity = PerGroup(group_size)
+            weight_dtype = getattr(torch, f"int{bit_width}")
+
+            try:
+                quantize_(
+                    model,
+                    int8_dynamic_activation_intx_weight(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
+                    ),
+                )
+            except Exception as e:
+                print(f"Encountered error during quantization: {e}")
+                print("Trying with PlainLayout")
+                quantize_(
+                    model,
+                    int8_dynamic_activation_intx_weight(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                        layout=PlainLayout(),
+                    ),
+                )
+
+            if not support_tensor_subclass:
+                unwrap_tensor_subclass(model)
+            continue
+
+        if quantizer == "embedding:wx":
             # These quantizers require float32 input weights. Note that after quantization,
             # the weights will no longer be float32, but lowbit integers
             if get_precision() != torch.float32:
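A worked instance of the argument mapping in the branch above, using the config from CI; the printed reprs are what torchao's granularity dataclasses are assumed to produce:

```
import torch
from torchao.quantization.granularity import PerGroup, PerRow

# CI config: {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}
q_kwargs = {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": False}

# groupsize of -1 would select channelwise (PerRow) granularity.
granularity = (
    PerRow() if q_kwargs["groupsize"] == -1 else PerGroup(q_kwargs["groupsize"])
)
weight_dtype = getattr(torch, f"int{q_kwargs['bitwidth']}")

print(granularity)   # e.g. PerGroup(group_size=128)
print(weight_dtype)  # torch.int3
```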
@@ -889,10 +941,12 @@ def quantized_model(self) -> nn.Module:
 # class references
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
+    "embedding:wx": IntxWeightEmbeddingQuantizer,
     "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
     "linear:int4": Int4WeightOnlyQuantizer,
+    "linear:a8wxdq": None,  # uses quantize_ API
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
 
@@ -916,26 +970,11 @@ def quantized_model(self) -> nn.Module:
         torchao_experimental_quant_api
     )
     from torchao_experimental_quant_api import (
-        Int8DynActIntxWeightLinearQuantizer,
-        IntxWeightEmbeddingQuantizer,
         UIntxWeightOnlyLinearQuantizer,
     )
-
-    quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
-    quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
     quantizer_class_dict["linear:afpwx"] = UIntxWeightOnlyLinearQuantizer
 
     # Try loading custom op
-    try:
-        import glob
-
-        libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/libtorchao_ops_aten.*")
-        libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
-        torch.ops.load_library(libs[0])
-        print("Loaded torchao cpu ops.")
-    except Exception as e:
-        print("Unable to load torchao cpu ops library. Slow fallback kernels will be used.")
-
     try:
         libname = "libtorchao_ops_mps_aten.dylib"
         libpath = f"{torchao_build_path}/cmake-out/lib/{libname}"
