Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 5c74843

Browse files
committed
add new torchao experimental kernels to torchchat
1 parent 6fae164 commit 5c74843

File tree

8 files changed

+129
-6
lines changed

8 files changed

+129
-6
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ __pycache__/
1414
# Build directories
1515
build/android/*
1616
et-build/*
17+
torchao-build/*
1718
runner-et/cmake-out/*
1819
runner-aoti/cmake-out/*
1920
cmake-out/
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3fa38aaf1276e36845a82fb399e5054718a441c4

runner/aoti.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,7 @@ if(Torch_FOUND)
2828
target_link_libraries(aoti_run "${TORCH_LIBRARIES}" m)
2929
set_property(TARGET aoti_run PROPERTY CXX_STANDARD 17)
3030
endif()
31+
32+
# Optionally link the prebuilt torchao low-bit custom-op library (ATen flavor)
# into aoti_run. LINK_TORCHAO_CUSTOM_OPS is expected to be passed on the cmake
# command line (-DLINK_TORCHAO_CUSTOM_OPS=ON|OFF).
#
# This block sits outside the if(Torch_FOUND) block that configures aoti_run,
# so the target may not exist when Torch was not found (NOTE(review):
# presumably add_executable(aoti_run ...) lives inside that block — confirm).
# Requiring the target here keeps configuration from aborting in that case.
#
# The keyword-less target_link_libraries() signature is kept on purpose: the
# existing call for aoti_run is keyword-less, and CMake rejects mixing the
# plain and PRIVATE/PUBLIC signatures on one target.
if(LINK_TORCHAO_CUSTOM_OPS AND TARGET aoti_run)
  target_link_libraries(aoti_run "${TORCHCHAT_ROOT}/torchao-build/cmake-out/liblowbit_op_aten${CMAKE_SHARED_LIBRARY_SUFFIX}")
endif()

runner/et.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ if(executorch_FOUND)
111111
target_link_libraries(et_run PRIVATE log)
112112
endif()
113113

114+
# When LINK_TORCHAO_CUSTOM_OPS is enabled, link et_run against the torchao
# low-bit custom-op shared library built under torchao-build/cmake-out.
if(LINK_TORCHAO_CUSTOM_OPS)
  target_link_libraries(
    et_run
    PRIVATE
    "${TORCHCHAT_ROOT}/torchao-build/cmake-out/liblowbit_op_executorch${CMAKE_SHARED_LIBRARY_SUFFIX}"
  )
endif()
117+
114118
# Adding target_link_options_shared_lib as commented out below leads to this:
115119
#
116120
# CMake Error at Utils.cmake:22 (target_link_options):

torchchat/utils/quantize.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,19 @@ def quantize_model(
9696
precision = get_precision()
9797

9898
try:
99-
# Easier to ask forgiveness than permission
100-
quant_handler = ao_quantizer_class_dict[quantizer](
101-
groupsize=q_kwargs["groupsize"], device=device, precision=precision
102-
)
99+
if quantizer == "linear:a8wxdq":
100+
quant_handler = ao_quantizer_class_dict[quantizer](
101+
device=device,
102+
precision=precision,
103+
bitwidth=q_kwargs.get("bitwidth", 4),
104+
groupsize=q_kwargs.get("groupsize", 128),
105+
has_weight_zeros=q_kwargs.get("has_weight_zeros", False),
106+
)
107+
else:
108+
# Easier to ask forgiveness than permission
109+
quant_handler = ao_quantizer_class_dict[quantizer](
110+
groupsize=q_kwargs["groupsize"], device=device, precision=precision
111+
)
103112
except TypeError as e:
104113
if "unexpected keyword argument 'device'" in str(e):
105114
quant_handler = ao_quantizer_class_dict[quantizer](
@@ -861,3 +870,33 @@ def quantized_model(self) -> nn.Module:
861870
"linear:int4": Int4WeightOnlyQuantizer,
862871
"linear:a8w4dq": Int8DynActInt4WeightQuantizer,
863872
}
873+
874+
# Best-effort registration of the torchao experimental "linear:a8wxdq"
# quantizer. The quantizer module is loaded straight from the source tree under
# <cwd>/torchao-build; any failure is reported and otherwise ignored so that
# importing this module still succeeds without the experimental build.
try:
    import importlib.util
    import sys
    import os
    torchao_build_path = f"{os.getcwd()}/torchao-build"

    # Try loading quantizer
    torchao_experimental_quant_api_spec = importlib.util.spec_from_file_location(
        "torchao_experimental_quant_api",
        f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py",
    )
    torchao_experimental_quant_api = importlib.util.module_from_spec(torchao_experimental_quant_api_spec)
    sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
    try:
        torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api)
    except BaseException:
        # Do not leave a half-initialised module registered: a later
        # `import torchao_experimental_quant_api` would otherwise silently
        # pick up the broken module object.
        sys.modules.pop("torchao_experimental_quant_api", None)
        raise
    from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
    ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer

    # Try loading custom op
    try:
        import glob
        libs = glob.glob(f"{torchao_build_path}/cmake-out/liblowbit_op_aten.*")
        # Match real shared-library extensions only (".so" / ".dylib"), and
        # sort so the chosen library is deterministic if several match.
        libs = sorted(lib for lib in libs if lib.endswith((".so", ".dylib")))
        if not libs:
            raise FileNotFoundError(
                f"no liblowbit_op_aten shared library found in {torchao_build_path}/cmake-out"
            )
        torch.ops.load_library(libs[0])
    except Exception as e:
        # The quantizer stays registered even when the native library is missing.
        print("Failed to load torchao custom op library with error: ", e)
        print("Slow fallback kernels will be used.")

except Exception as e:
    print(f"Failed to load torchao experimental a8wxdq quantizer with error: {e}")

torchchat/utils/scripts/build_native.sh

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ if [ $# -eq 0 ]; then
2525
show_help
2626
exit 1
2727
fi
28+
29+
LINK_TORCHAO=OFF
2830
while (( "$#" )); do
2931
case "$1" in
3032
-h|--help)
@@ -41,6 +43,11 @@ while (( "$#" )); do
4143
TARGET="et"
4244
shift
4345
;;
46+
link_torchao)
47+
echo "Linking with torchao custom ops..."
48+
LINK_TORCHAO=ON
49+
shift
50+
;;
4451
*)
4552
echo "Invalid option: $1"
4653
show_help
@@ -72,14 +79,20 @@ if [[ "$TARGET" == "et" ]]; then
7279
install_pip_dependencies
7380
clone_executorch
7481
install_executorch_libs false
82+
83+
if [[ "$LINK_TORCHAO" == "ON" ]]; then
84+
EXECUTORCH_INCLUDE_DIRS="${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/src"
85+
EXECUTORCH_LIBRARIES="${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libexecutorch_no_prim_ops.a"
86+
install_torchao_custom_executorch_ops
87+
fi
7588
fi
7689
popd
7790

7891
# CMake commands
7992
# Configure the native runner build. The two targets differ only in the
# libstdc++ ABI flag: 1 for the ExecuTorch runner ("et"), 0 otherwise.
if [[ "$TARGET" == "et" ]]; then
    GLIBCXX_USE_CXX11_ABI=1
else
    GLIBCXX_USE_CXX11_ABI=0
fi

# The command substitution is quoted so a torch install path containing spaces
# is passed to cmake as a single -DCMAKE_PREFIX_PATH argument.
cmake -S . -B ./cmake-out \
    -DCMAKE_PREFIX_PATH="$(python3 -c 'import torch;print(torch.utils.cmake_prefix_path)')" \
    -DLINK_TORCHAO_CUSTOM_OPS="${LINK_TORCHAO}" \
    -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=${GLIBCXX_USE_CXX11_ABI}" \
    -G Ninja
8497
cmake --build ./cmake-out --target "${TARGET}"_run
8598

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Fetches torchao and builds its experimental custom-op library for ATen,
# using the helper functions sourced from install_utils.sh next to this script.

source "$(dirname "${BASH_SOURCE[0]}")/install_utils.sh"

# Abort early when TORCHCHAT_ROOT is unset or empty: a bare `pushd` would fail,
# and the helpers below would then build paths rooted at "/torchao-build".
pushd "${TORCHCHAT_ROOT:?TORCHCHAT_ROOT must be set to the torchchat checkout root}"
find_cmake_prefix_path
clone_torchao
install_torchao_custom_aten_ops
popd

torchchat/utils/scripts/install_utils.sh

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,48 @@ install_executorch_libs() {
162162

163163
install_executorch_python_libs $1
164164
}
165+
166+
# Fetch a fresh torchao checkout into ${TORCHCHAT_ROOT}/torchao-build/src/ao,
# pinned to the commit recorded in the torchao experimental pin file.
# Any previous checkout under torchao-build/src is removed first.
# Returns non-zero if the clone or the checkout of the pinned commit fails.
clone_torchao() {
  # Refuse to run with an empty root so the rm -rf below can never resolve to
  # "/torchao-build/src".
  local src_dir="${TORCHCHAT_ROOT:?TORCHCHAT_ROOT must be set}/torchao-build/src"
  # NOTE(review): pin location taken from the previously commented-out
  # checkout line with its "intstall" typo corrected — confirm the path.
  local pin_file="${TORCHCHAT_ROOT}/install/.pins/torchao-experimental-pin.txt"
  local status=0

  echo "Cloning torchao to ${src_dir}"
  rm -rf "${src_dir}"
  mkdir -p "${src_dir}"
  pushd "${src_dir}" || return 1
  pwd

  # Clone from the public repository instead of copying a developer-local
  # source tree, then pin to the recorded commit.
  git clone https://github.com/pytorch/ao.git \
    && git -C ao checkout "$(cat "${pin_file}")" \
    || status=$?

  popd
  return "${status}"
}
180+
181+
# Configure and build the torchao custom-op example for ATen into
# ${TORCHCHAT_ROOT}/torchao-build/cmake-out.
# Reads MY_CMAKE_PREFIX_PATH from the environment; exports TORCHAO_INCLUDE_DIRS
# and sets CMAKE_OUT_DIR as the original did.
# Returns non-zero if the source directory is missing or cmake fails.
install_torchao_custom_aten_ops() {
  echo "Building torchao custom ops for ATen"
  # Bail out if the torchao sources are missing; otherwise `cmake -S .` would
  # configure whatever directory the caller happens to be in.
  pushd "${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op" || return 1
  export TORCHAO_INCLUDE_DIRS="${TORCHCHAT_ROOT}/torchao-build/src/ao"

  CMAKE_OUT_DIR="${TORCHCHAT_ROOT}/torchao-build/cmake-out"
  cmake -DTORCHAO_INCLUDE_DIRS="${TORCHAO_INCLUDE_DIRS}" \
    -DCMAKE_PREFIX_PATH="${MY_CMAKE_PREFIX_PATH}" \
    -DPLATFORM="ATEN" \
    -S . \
    -B "${CMAKE_OUT_DIR}" -G Ninja \
    && cmake --build "${CMAKE_OUT_DIR}"
  local status=$?

  # Restore the caller's directory stack; the original left it unbalanced, so
  # the caller's own popd returned to the wrong directory.
  popd
  return "${status}"
}
194+
195+
# Configure and build the torchao custom-op example for ExecuTorch into
# ${TORCHCHAT_ROOT}/torchao-build/cmake-out.
# Reads MY_CMAKE_PREFIX_PATH, EXECUTORCH_INCLUDE_DIRS and EXECUTORCH_LIBRARIES
# from the caller's environment; exports TORCHAO_INCLUDE_DIRS and sets
# CMAKE_OUT_DIR as the original did.
# Returns non-zero if the source directory is missing or cmake fails.
install_torchao_custom_executorch_ops() {
  echo "Building torchao custom ops for ExecuTorch"
  # Bail out if the torchao sources are missing; otherwise `cmake -S .` would
  # configure whatever directory the caller happens to be in.
  pushd "${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op" || return 1
  export TORCHAO_INCLUDE_DIRS="${TORCHCHAT_ROOT}/torchao-build/src/ao"

  CMAKE_OUT_DIR="${TORCHCHAT_ROOT}/torchao-build/cmake-out"
  cmake -DTORCHAO_INCLUDE_DIRS="${TORCHAO_INCLUDE_DIRS}" \
    -DCMAKE_PREFIX_PATH="${MY_CMAKE_PREFIX_PATH}" \
    -DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \
    -DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \
    -DPLATFORM="EXECUTORCH" \
    -S . \
    -B "${CMAKE_OUT_DIR}" -G Ninja \
    && cmake --build "${CMAKE_OUT_DIR}"
  local status=$?

  # Restore the caller's directory stack; the original left it unbalanced, so
  # the caller's own popd returned to the wrong directory.
  popd
  return "${status}"
}

0 commit comments

Comments
 (0)