Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 90ab9e3

Browse files
committed
Add torchao (#1182)
* init * update install utils * update * update libs * update torchao pin * fix ci test * add python et install to ci * fix ci errors * fixes * fixes * fixes * fixes * fixes * fixes * fixes
1 parent 68b4631 commit 90ab9e3

File tree

3 files changed

+163
-133
lines changed

3 files changed

+163
-133
lines changed

.github/workflows/pull.yml

Lines changed: 153 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -621,71 +621,87 @@ jobs:
621621
python torchchat.py remove stories15m
622622
623623
test-mps:
624-
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
625-
with:
626-
runner: macos-m1-stable # neeps MPS, was macos-m1-stable
627-
script: |
628-
export PYTHON_VERSION="3.10"
629-
set -x
630-
# NS/MC: Remove previous installation of torch and torchao first
631-
# as this script does not install anything into conda env but rather as system dep
632-
pip3 uninstall -y torch || true
633-
set -eou pipefail
634-
635-
pip3 uninstall -y torchao || true
636-
set -eou pipefail
637-
638-
echo "::group::Print machine info"
639-
uname -a
640-
sysctl machdep.cpu.brand_string
641-
sysctl machdep.cpu.core_count
642-
echo "::endgroup::"
624+
strategy:
625+
matrix:
626+
runner: [macos-m1-stable ]
627+
runs-on: ${{matrix.runner}}
628+
steps:
629+
- name: Checkout repo
630+
uses: actions/checkout@v2
631+
- name: Setup Python
632+
uses: actions/setup-python@v2
633+
with:
634+
python-version: 3.10.11
635+
- name: Print machine info
636+
run: |
637+
uname -a
638+
if [ $(uname -s) == Darwin ]; then
639+
sysctl machdep.cpu.brand_string
640+
sysctl machdep.cpu.core_count
641+
fi
642+
- name: Run test
643+
run: |
644+
export PYTHON_VERSION="3.10"
645+
set -x
646+
# NS/MC: Remove previous installation of torch and torchao first
647+
# as this script does not install anything into conda env but rather as system dep
648+
pip3 uninstall -y torch || true
649+
set -eou pipefail
643650
644-
echo "::group::Install requirements"
645-
# Install requirements
646-
./install/install_requirements.sh
647-
ls -la
648-
pwd
649-
pip3 list
650-
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
651-
echo "::endgroup::"
651+
pip3 uninstall -y torchao || true
652+
set -eou pipefail
652653
653-
echo "::group::Download checkpoints"
654-
(
655-
mkdir -p checkpoints/stories15M
656-
pushd checkpoints/stories15M
657-
curl -fsSL -O https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
658-
curl -fsSL -O https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
659-
popd
660-
)
661-
echo "::endgroup::"
654+
echo "::group::Print machine info"
655+
uname -a
656+
sysctl machdep.cpu.brand_string
657+
sysctl machdep.cpu.core_count
658+
echo "::endgroup::"
662659
663-
echo "::group::Run inference"
664-
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
665-
export MODEL_NAME=stories15M
666-
export MODEL_DIR=/tmp
660+
echo "::group::Install requirements"
661+
# Install requirements
662+
./install/install_requirements.sh
663+
ls -la
664+
pwd
665+
pip3 list
666+
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
667+
echo "::endgroup::"
668+
669+
echo "::group::Download checkpoints"
670+
(
671+
mkdir -p checkpoints/stories15M
672+
pushd checkpoints/stories15M
673+
curl -fsSL -O https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
674+
curl -fsSL -O https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
675+
popd
676+
)
677+
echo "::endgroup::"
678+
679+
echo "::group::Run inference"
680+
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
681+
export MODEL_NAME=stories15M
682+
export MODEL_DIR=/tmp
667683
668-
python3 torchchat.py generate --device mps --checkpoint-path ${MODEL_PATH} --temperature 0
684+
python3 torchchat.py generate --device mps --checkpoint-path ${MODEL_PATH} --temperature 0
669685
670-
echo "************************************************************"
671-
echo "*** embedding"
672-
echo "************************************************************"
686+
echo "************************************************************"
687+
echo "*** embedding"
688+
echo "************************************************************"
673689
674-
python3 torchchat.py generate --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
675-
python3 torchchat.py generate --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
690+
python3 torchchat.py generate --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
691+
python3 torchchat.py generate --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
676692
677-
echo "************************************************************"
678-
echo "*** linear int8"
679-
echo "************************************************************"
693+
echo "************************************************************"
694+
echo "*** linear int8"
695+
echo "************************************************************"
680696
681-
python3 torchchat.py generate --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
682-
python3 torchchat.py generate --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
697+
python3 torchchat.py generate --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
698+
python3 torchchat.py generate --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
683699
684-
echo "************************************************************"
685-
echo "*** linear int4"
686-
echo "************************************************************"
700+
echo "************************************************************"
701+
echo "*** linear int4"
702+
echo "************************************************************"
687703
688-
PYTORCH_ENABLE_MPS_FALLBACK=1 python3 torchchat.py generate --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0
704+
PYTORCH_ENABLE_MPS_FALLBACK=1 python3 torchchat.py generate --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0
689705
test-gguf-util:
690706
strategy:
691707
matrix:
@@ -734,66 +750,82 @@ jobs:
734750
735751
echo "Tests complete."
736752
test-mps-dtype:
737-
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
738-
with:
739-
runner: macos-m1-stable # needs MPS, was macos-m1-stable
740-
script: |
741-
export PYTHON_VERSION="3.10"
742-
set -x
743-
# NS/MC: Remove previous installation of torch and torchao first
744-
# as this script does not install anything into conda env but rather as system dep
745-
pip3 uninstall -y torch || true
746-
set -eou pipefail
747-
748-
pip3 uninstall -y torchao || true
749-
set -eou pipefail
750-
751-
echo "::group::Print machine info"
752-
uname -a
753-
sysctl machdep.cpu.brand_string
754-
sysctl machdep.cpu.core_count
755-
echo "::endgroup::"
753+
strategy:
754+
matrix:
755+
runner: [macos-m1-stable ]
756+
runs-on: ${{matrix.runner}}
757+
steps:
758+
- name: Checkout repo
759+
uses: actions/checkout@v2
760+
- name: Setup Python
761+
uses: actions/setup-python@v2
762+
with:
763+
python-version: 3.10.11
764+
- name: Print machine info
765+
run: |
766+
uname -a
767+
if [ $(uname -s) == Darwin ]; then
768+
sysctl machdep.cpu.brand_string
769+
sysctl machdep.cpu.core_count
770+
fi
771+
- name: Run test
772+
run: |
773+
export PYTHON_VERSION="3.10"
774+
set -x
775+
# NS/MC: Remove previous installation of torch and torchao first
776+
# as this script does not install anything into conda env but rather as system dep
777+
pip3 uninstall -y torch || true
778+
set -eou pipefail
756779
757-
echo "::group::Install requirements"
758-
# Install requirements
759-
./install/install_requirements.sh
760-
ls -la
761-
pwd
762-
pip3 list
763-
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
764-
echo "::endgroup::"
780+
pip3 uninstall -y torchao || true
781+
set -eou pipefail
765782
766-
echo "::group::Download checkpoints"
767-
(
768-
mkdir -p checkpoints/stories15M
769-
pushd checkpoints/stories15M
770-
curl -fsSL -O https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
771-
curl -fsSL -O https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
772-
popd
773-
)
774-
echo "::endgroup::"
783+
echo "::group::Print machine info"
784+
uname -a
785+
sysctl machdep.cpu.brand_string
786+
sysctl machdep.cpu.core_count
787+
echo "::endgroup::"
775788
776-
echo "::group::Run inference"
777-
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
778-
export MODEL_NAME=stories15M
779-
export MODEL_DIR=/tmp
780-
for DTYPE in float16 float32; do
781-
# if [ $(uname -s) == Darwin ]; then
782-
# export DTYPE=float16
783-
# fi
789+
echo "::group::Install requirements"
790+
# Install requirements
791+
./install/install_requirements.sh
792+
ls -la
793+
pwd
794+
pip3 list
795+
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
796+
echo "::endgroup::"
797+
798+
echo "::group::Download checkpoints"
799+
(
800+
mkdir -p checkpoints/stories15M
801+
pushd checkpoints/stories15M
802+
curl -fsSL -O https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
803+
curl -fsSL -O https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
804+
popd
805+
)
806+
echo "::endgroup::"
807+
808+
echo "::group::Run inference"
809+
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
810+
export MODEL_NAME=stories15M
811+
export MODEL_DIR=/tmp
812+
for DTYPE in float16 float32; do
813+
# if [ $(uname -s) == Darwin ]; then
814+
# export DTYPE=float16
815+
# fi
784816
785-
python3 torchchat.py generate --dtype ${DTYPE} --device mps --checkpoint-path ${MODEL_PATH} --temperature 0
817+
python3 torchchat.py generate --dtype ${DTYPE} --device mps --checkpoint-path ${MODEL_PATH} --temperature 0
786818
787-
python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
819+
python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
788820
789-
python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
821+
python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
790822
791-
python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
823+
python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
792824
793-
python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
825+
python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
794826
795-
PYTORCH_ENABLE_MPS_FALLBACK=1 python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0
796-
done
827+
PYTORCH_ENABLE_MPS_FALLBACK=1 python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0
828+
done
797829
compile-gguf:
798830
strategy:
799831
matrix:
@@ -918,11 +950,11 @@ jobs:
918950
- name: Install ExecuTorch python
919951
run: |
920952
echo "Install ExecuTorch python"
921-
pushd et-build/src/executorch
922-
chmod +x ./install_requirements.sh
923-
chmod +x ./install_requirements.py
924-
./install_requirements.sh
925-
popd
953+
export TORCHCHAT_ROOT=$PWD
954+
export ET_BUILD_DIR="et-build"
955+
ENABLE_ET_PYBIND="${1:-true}"
956+
source "torchchat/utils/scripts/install_utils.sh"
957+
install_executorch_python_libs $ENABLE_ET_PYBIND
926958
- name: Install runner
927959
run: |
928960
echo "Installing runner"
@@ -1067,14 +1099,12 @@ jobs:
10671099
echo "et-git-hash=$(cat ${TORCHCHAT_ROOT}/install/.pins/et-pin.txt)" >> "$GITHUB_ENV"
10681100
- name: Load or install ET
10691101
id: install-et
1070-
uses: actions/cache@v3
1071-
env:
1072-
cache-key: et-build-${{runner.os}}-${{runner.arch}}-${{env.et-git-hash}}
1102+
uses: actions/cache@v4
10731103
with:
1074-
path: ./et-build
1075-
key: ${{env.cache-key}}
1076-
restore-keys: |
1077-
${{env.cache-key}}
1104+
path: |
1105+
./et-build
1106+
./torchchat/utils/scripts
1107+
key: et-build-${{runner.os}}-${{runner.arch}}-${{env.et-git-hash}}-${{ hashFiles('**/install_et.sh') }}
10781108
- if: ${{ steps.install-et.outputs.cache-hit != 'true' }}
10791109
continue-on-error: true
10801110
run: |
@@ -1083,11 +1113,11 @@ jobs:
10831113
- name: Install ExecuTorch python
10841114
run: |
10851115
echo "Install ExecuTorch python"
1086-
pushd et-build/src/executorch
1087-
chmod +x ./install_requirements.sh
1088-
chmod +x ./install_requirements.py
1089-
./install_requirements.sh
1090-
popd
1116+
export TORCHCHAT_ROOT=$PWD
1117+
export ET_BUILD_DIR="et-build"
1118+
ENABLE_ET_PYBIND="${1:-true}"
1119+
source "torchchat/utils/scripts/install_utils.sh"
1120+
install_executorch_python_libs $ENABLE_ET_PYBIND
10911121
- name: Install runner
10921122
run: |
10931123
echo "Installing runner"

torchchat/utils/scripts/install_et.sh

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,4 @@ pushd ${TORCHCHAT_ROOT}
1919
find_cmake_prefix_path
2020
clone_executorch
2121
install_executorch_libs $ENABLE_ET_PYBIND
22-
install_executorch_python_libs $ENABLE_ET_PYBIND
23-
# TODO: figure out the root cause of 'AttributeError: module 'evaluate'
24-
# has no attribute 'utils'' error from evaluate CI jobs and remove
25-
# `import lm_eval` from torchchat.py since it requires a specific version
26-
# of numpy.
27-
pip install numpy=='1.26.4'
2822
popd

torchchat/utils/scripts/install_utils.sh

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ install_executorch_python_libs() {
9393
echo "Installing pybind"
9494
bash ./install_requirements.sh --pybind xnnpack
9595
fi
96+
97+
# TODO: figure out the root cause of 'AttributeError: module 'evaluate'
98+
# has no attribute 'utils'' error from evaluate CI jobs and remove
99+
# `import lm_eval` from torchchat.py since it requires a specific version
100+
# of numpy.
101+
pip install numpy=='1.26.4'
102+
96103
pip3 list
97104
popd
98105
}
@@ -169,10 +176,9 @@ clone_torchao() {
169176
pushd ${TORCHCHAT_ROOT}/torchao-build/src
170177
echo $pwd
171178

172-
cp -R $HOME/fbsource/fbcode/pytorch/ao .
173-
# git clone https://github.com/pytorch/ao.git
174-
# cd ao
175-
# git checkout $(cat ${TORCHCHAT_ROOT}/intstall/.pins/torchao-pin.txt)
179+
git clone https://github.com/pytorch/ao.git
180+
cd ao
181+
git checkout $(cat ${TORCHCHAT_ROOT}/install/.pins/torchao-pin.txt)
176182

177183
popd
178184
}

0 commit comments

Comments
 (0)