This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit ce41944

Move PT pin and AO commit hash to benefit from SDPA MPS implementation (#964)
* Move PT pin and AO commit hash to benefit from SDPA MPS implementation in PT
* Move PT nightly to 2024-07-28
* Shuffle weights prior to _convert_weight_to_int4pack call in gguf_loader
* Uninstall torchao on M1 before re-installing it
* Update ET pin
* Uninstall torchao on M1 before re-installing it
* Update ET pin
1 parent fe73ef7 commit ce41944

5 files changed: +16 -9 lines changed

.github/workflows/pull.yml

Lines changed: 9 additions & 4 deletions
@@ -627,11 +627,14 @@ jobs:
       runner: macos-m1-stable # neeps MPS, was macos-m1-stable
       script: |
         set -x
-        # NS: Remove previous installation of torch first
-        # as this script does not isntall anything into conda env but rather as system dep
+        # NS/MC: Remove previous installation of torch and torchao first
+        # as this script does not install anything into conda env but rather as system dep
         pip3 uninstall -y torch || true
         set -eou pipefail

+        pip3 uninstall -y torchao || true
+        set -eou pipefail
+
         echo "::group::Print machine info"
         uname -a
         sysctl machdep.cpu.brand_string
@@ -736,10 +739,12 @@ jobs:
       runner: macos-m1-stable # needs MPS, was macos-m1-stable
       script: |
         set -x
-        # NS: Remove previous installation of torch first
-        # as this script does not isntall anything into conda env but rather as system dep
+        # NS/MC: Remove previous installation of torch and torchao first
+        # as this script does not install anything into conda env but rather as system dep
         pip3 uninstall -y torch || true
+        set -eou pipefail

+        pip3 uninstall -y torchao || true
         set -eou pipefail

         echo "::group::Print machine info"

.pins/et-pin.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-c7574994ecd775fdaacc0f2de27089526e05b108
+91298923a0076c1b41059efb6dad2876426e4b03

build/gguf_loader.py

Lines changed: 2 additions & 1 deletion
@@ -170,8 +170,9 @@ def load_model_and_state_dict(
         if load_state_dict:
             q, s, z = Q4_0.unpack(t)
             scales_and_zeros = pack_scales_and_zeros(s, z)
+            q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
             weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                q, inner_k_tiles
+                q_uint8, inner_k_tiles
             )
             state_dict[f"{fqn}.weight"] = weight_int4pack
             state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros

install_requirements.sh

Lines changed: 2 additions & 2 deletions
@@ -47,7 +47,7 @@ fi
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-NIGHTLY_VERSION=dev20240710
+NIGHTLY_VERSION=dev20240728

 # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
 (
@@ -82,7 +82,7 @@ REQUIREMENTS_TO_INSTALL=(
 # TODO: Remove this and install nightly build, once it supports macos
 (
   set -x
-  $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@d36de1b144b73bf753bd082109c2b5d0141abd5b
+  $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@d477c0e59b458b5617dcb3e999290a87df3070d8
 )
 if [[ -x "$(command -v nvidia-smi)" ]]; then
 (
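The NIGHTLY_VERSION bump above is how the commit picks up the SDPA MPS implementation mentioned in its title. A small smoke test, assuming an Apple-silicon machine and a nightly at least as new as dev20240728; the snippet is illustrative and not part of the repo:

import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    # (batch, heads, seq_len, head_dim) in half precision on the MPS device.
    q = torch.randn(1, 4, 64, 16, device="mps", dtype=torch.float16)
    out = F.scaled_dot_product_attention(q, q, q)
    print("SDPA on MPS:", out.shape, out.dtype)
else:
    print("MPS not available on", torch.__version__)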

quantization/qops.py

Lines changed: 2 additions & 1 deletion
@@ -401,8 +401,9 @@ def _prepare_weight_and_scales_and_zeros(
     weight_int32, scales_and_zeros = group_quantize_tensor(
         weight_bf16, n_bit=4, groupsize=groupsize
     )
+    weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(torch.uint8)
     weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-        weight_int32, inner_k_tiles
+        weight_uint8, inner_k_tiles
     )
     return weight_int4pack, scales_and_zeros

