diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index c48436a80..d25c674dd 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -1124,3 +1124,41 @@ jobs:
           echo "Generate AOTI"
           python torchchat.py generate stories110M --aoti-package-path ./model.pt2 --prompt "${PRMT}"
           echo "Tests complete."
+
+  test-torchao-experimental-mps:
+    strategy:
+      matrix:
+        runner: [macos-m1-stable]
+    runs-on: ${{matrix.runner}}
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v3
+        with:
+          submodules: true
+      - name: Setup Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.10.11
+      - name: Print machine info
+        run: |
+          uname -a
+          if [ $(uname -s) == Darwin ]; then
+            sysctl machdep.cpu.brand_string
+            sysctl machdep.cpu.core_count
+          fi
+      - name: Install torchchat
+        run: |
+          echo "Installing pip3 packages"
+          ./install/install_requirements.sh
+          pip3 list
+          python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
+      - name: Install torchao-ops-mps
+        id: install-torchao-ops-mps
+        run: |
+          bash torchchat/utils/scripts/build_torchao_ops.sh mps
+      - name: Run inference
+        run: |
+          python torchchat.py download stories110M
+          export PRMT="Once upon a time in a land far away"
+          echo "Generate eager"
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device mps --dtype float32 --quantize '{"linear:afpwx": {"bitwidth": 3, "groupsize": 32}}'
diff --git a/docs/quantization.md b/docs/quantization.md
index 5007946bb..08086d8d1 100644
--- a/docs/quantization.md
+++ b/docs/quantization.md
@@ -196,6 +196,32 @@ Note: only the ExecuTorch C++ runner in torchchat when built using the instructi
 ./cmake-out/et_run llama3_1.pte -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -l 3 -i "Once upon a time,"
 ```
 
+## Experimental TorchAO MPS lowbit kernels
+
+WARNING: These kernels only work on devices with Apple Silicon.
+
+### Use
+
+#### linear:afpwx
+The quantization scheme linear:afpwx quantizes only the weights in a groupwise manner with a specified bitwidth and groupsize.
+It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7) and groupsize (32, 64, 128, 256).
+
+### Setup
+To use linear:afpwx, you must set up the torchao MPS experimental kernels. These will only work on devices with Apple Silicon.
+Currently, torchchat can only run them in eager mode.
+
+From the torchchat root directory, run
+```
+sh torchchat/utils/scripts/build_torchao_ops.sh mps
+```
+
+### Examples
+
+#### Eager mode
+```
+python3 torchchat.py generate stories110M --device mps --dtype float32 --quantize '{"linear:afpwx": {"bitwidth": 4, "groupsize": 256}}' --prompt "Once upon a time," --num-samples 5
+```
+
 ## Quantization Profiles
 
 Four [sample profiles](https://github.com/pytorch/torchchat/tree/main/torchchat/quant_config/) are included with the torchchat distribution: `cuda.json`, `desktop.json`, `mobile.json`, `pi5.json`
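The documented argument space for linear:afpwx is small enough to validate up front. Below is a minimal, hypothetical sketch (not torchchat API) of assembling the `--quantize` JSON; the `make_afpwx_config` helper and its checks exist only for illustration, with the valid values taken from the doc text above.

```
import json

# Valid argument values documented for linear:afpwx in docs/quantization.md above.
VALID_BITWIDTHS = {1, 2, 3, 4, 5, 6, 7}
VALID_GROUPSIZES = {32, 64, 128, 256}

def make_afpwx_config(bitwidth: int, groupsize: int) -> str:
    # Returns the JSON string expected by torchchat's --quantize flag.
    if bitwidth not in VALID_BITWIDTHS:
        raise ValueError(f"bitwidth must be one of {sorted(VALID_BITWIDTHS)}, got {bitwidth}")
    if groupsize not in VALID_GROUPSIZES:
        raise ValueError(f"groupsize must be one of {sorted(VALID_GROUPSIZES)}, got {groupsize}")
    return json.dumps({"linear:afpwx": {"bitwidth": bitwidth, "groupsize": groupsize}})

print(make_afpwx_config(4, 256))
# {"linear:afpwx": {"bitwidth": 4, "groupsize": 256}}
```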
diff --git a/install/.pins/torchao-pin.txt b/install/.pins/torchao-pin.txt
index 40f083249..80a4751bc 100644
--- a/install/.pins/torchao-pin.txt
+++ b/install/.pins/torchao-pin.txt
@@ -1 +1 @@
-c8f1174a06dcc0102849c8348ca6573bde8847a9
+7d7c14e898eca3fe66138d2a9445755a9270b800
diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py
index 31c639dfd..6ac2410d0 100644
--- a/torchchat/utils/quantize.py
+++ b/torchchat/utils/quantize.py
@@ -63,10 +63,10 @@ def get_named_parameters(func: Callable) -> List[str]:
     # Get the signature of the function
     signature = inspect.signature(func)
-    
+
     # Extract the parameters from the signature
     parameters = signature.parameters
-    
+
     # Filter and return named parameters
     named_params = [
         name for name, param in parameters.items()
@@ -80,8 +80,8 @@ def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer:
             print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
             del q_kwargs[key]
     return q_kwargs
-    
-    
+
+
 #########################################################################
 ###                    torchchat quantization API                     ###
@@ -116,15 +116,18 @@ def quantize_model(
             if not support_tensor_subclass:
                 unwrap_tensor_subclass(model)
             continue
-        
+
         if quantizer in ["linear:a8wxdq", "embedding:wx"]:
             # These quantizers require float32 input weights. Note that after quantization,
             # the weights will no longer be float32, but lowbit integers
             if get_precision() != torch.float32:
                 print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
                 set_precision(torch.float32)
-        
-        # We set global precision from quantize options if it is specified at cli.py:485
+
+        if quantizer == "linear:afpwx" and device != "mps":
+            raise RuntimeError("linear:afpwx quantization can only run on mps device!")
+
+        # We set global precision from quantize options if it is specified at cli.py:485
         # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
         precision = get_precision()
@@ -915,10 +918,12 @@ def quantized_model(self) -> nn.Module:
     from torchao_experimental_quant_api import (
         Int8DynActIntxWeightLinearQuantizer,
         IntxWeightEmbeddingQuantizer,
+        UIntxWeightOnlyLinearQuantizer,
     )
     quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
     quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
+    quantizer_class_dict["linear:afpwx"] = UIntxWeightOnlyLinearQuantizer
 
     # Try loading custom op
     try:
@@ -928,15 +933,14 @@ def quantized_model(self) -> nn.Module:
         libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
         torch.ops.load_library(libs[0])
     except Exception as e:
-        print("Failed to torchao ops library with error: ", e)
-        print("Slow fallback kernels will be used.")
+        print("Unable to load torchao cpu ops library. Slow fallback kernels will be used.")
+
+    try:
+        libname = "libtorchao_ops_mps_aten.dylib"
+        libpath = f"{torchao_build_path}/cmake-out/lib/{libname}"
+        torch.ops.load_library(libpath)
+    except Exception as e:
+        print("Unable to load torchao mps ops library.")
 except Exception as e:
-    class ErrorHandler(QuantHandler):
-        def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
-            global torchao_experimental_load_error
-            raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")
-
-    torchao_experimental_load_error = e
-    quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
-    quantizer_class_dict["embedding:wx"] = ErrorHandler
+    print("Unable to import torchao experimental quant_api with error: ", e)
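The revised loading logic keeps the CPU and MPS libraries independent: each `torch.ops.load_library` call sits in its own try/except, so a missing build degrades to fallback kernels instead of aborting. A self-contained sketch of that pattern follows; the `torchao_build_path` value and the CPU library glob pattern are assumptions for illustration, since torchchat derives the real paths from its own build tree.

```
import glob
import torch

# Stand-in for the build directory torchchat computes from its install layout.
torchao_build_path = "torchao-build"

# CPU kernels: pick up whichever shared library the build produced
# (glob pattern assumed; the patch above filters on .so/.dylib the same way).
try:
    libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/libtorchao_ops_aten.*")
    libs = [l for l in libs if l.endswith(("so", "dylib"))]
    torch.ops.load_library(libs[0])
except Exception:
    print("Unable to load torchao cpu ops library. Slow fallback kernels will be used.")

# MPS kernels: a fixed dylib name, loaded independently so a CPU failure
# does not prevent the Metal kernels from registering (and vice versa).
try:
    torch.ops.load_library(f"{torchao_build_path}/cmake-out/lib/libtorchao_ops_mps_aten.dylib")
except Exception:
    print("Unable to load torchao mps ops library.")
```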
diff --git a/torchchat/utils/scripts/build_torchao_ops.sh b/torchchat/utils/scripts/build_torchao_ops.sh
index a8fd8bea2..46e2479ac 100644
--- a/torchchat/utils/scripts/build_torchao_ops.sh
+++ b/torchchat/utils/scripts/build_torchao_ops.sh
@@ -5,12 +5,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+device=${1:-cpu}
+if [[ "$device" != "cpu" && "$device" != "mps" ]]; then
+  echo "Invalid argument: $device. Valid values are 'cpu' or 'mps'." >&2
+  exit 1
+fi
 source "$(dirname "${BASH_SOURCE[0]}")/install_utils.sh"
 
 pushd ${TORCHCHAT_ROOT}
 find_cmake_prefix_path
 clone_torchao
-install_torchao_aten_ops
+install_torchao_aten_ops "$device"
 popd
diff --git a/torchchat/utils/scripts/install_utils.sh b/torchchat/utils/scripts/install_utils.sh
index 84966cc35..94378960a 100644
--- a/torchchat/utils/scripts/install_utils.sh
+++ b/torchchat/utils/scripts/install_utils.sh
@@ -184,8 +184,18 @@ clone_torchao() {
 }
 
 install_torchao_aten_ops() {
-  echo "Building torchao custom ops for ATen"
-  pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental
+  local device=${1:-cpu}
+
+  if [[ "$device" == "cpu" ]]; then
+    echo "Building torchao custom ops for ATen"
+    pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental
+  elif [[ "$device" == "mps" ]]; then
+    echo "Building torchao mps custom ops for ATen"
+    pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental/ops/mps
+  else
+    echo "Invalid argument: $device. Valid values are 'cpu' or 'mps'." >&2
+    return 1
+  fi
 
   CMAKE_OUT_DIR=${TORCHCHAT_ROOT}/torchao-build/cmake-out
   cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \
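Taken together, the scripts and the CI job suggest this end-to-end flow. The sketch below mirrors the new workflow steps under stated assumptions: it runs from the torchchat root on an Apple Silicon machine, and `stories110M` has already been fetched with `python3 torchchat.py download stories110M`.

```
import subprocess

# Build the torchao MPS kernels with the updated script (device argument "mps").
subprocess.run(
    ["bash", "torchchat/utils/scripts/build_torchao_ops.sh", "mps"],
    check=True,
)

# Run the same eager-mode linear:afpwx generation the docs and CI job exercise.
subprocess.run(
    [
        "python3", "torchchat.py", "generate", "stories110M",
        "--device", "mps", "--dtype", "float32",
        "--quantize", '{"linear:afpwx": {"bitwidth": 4, "groupsize": 256}}',
        "--prompt", "Once upon a time,",
    ],
    check=True,
)
```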