Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 3c3b367

Browse files
authored
Update more-tests.yml
Add tests for backends
1 parent 8e18e7f commit 3c3b367

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

.github/workflows/more-tests.yml

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,61 @@ jobs:
8383
echo "tests complete"
8484
echo "******************************************"
8585
echo "::endgroup::"
86+
87+
# Job: exercise every scaled-dot-product-attention (SDPA) backend that
# torchchat can select, across devices and dtypes, on a CUDA runner.
# Runs via pytorch/test-infra's reusable linux_job_v2 workflow.
test-sdpa-backends:
  permissions:
    id-token: write
    contents: read
  uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
  with:
    runner: linux.g5.4xlarge.nvidia.gpu
    gpu-arch-type: cuda
    gpu-arch-version: "12.4"
    timeout: 60
    script: |
      echo "::group::Print machine info"
      uname -a
      echo "::endgroup::"

      # FIX: this group was mislabeled "Download checkpoints", which both
      # misdescribed its contents and collided with the real download group
      # below, merging two unrelated sections in the Actions log view.
      echo "::group::Install requirements"
      ./install/install_requirements.sh cuda
      pip3 list
      python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
      echo "::endgroup::"

      echo "::group::Download checkpoints"
      mkdir -p checkpoints/stories15M
      pushd checkpoints/stories15M
      wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
      wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
      popd
      echo "::endgroup::"

      echo "::group::Run inference"
      export MODEL_PATH=checkpoints/stories15M/stories15M.pt
      export MODEL_NAME=stories15M
      export MODEL_DIR=/tmp

      # Sweep the full matrix: 2 devices x 3 dtypes x 4 SDPA backends.
      # NOTE(review): flash/cudnn backends presumably no-op or fall back on
      # cpu — confirm torchchat handles unsupported device/backend combos.
      for DEVICE in cpu cuda; do
        for DTYPE in bfloat16 float16 float32; do
          for SDPA in 'math' 'flash_attention' 'efficient_attention' 'cudnn_attention'; do
            ###################################################################
            # Python execution interpreted vanilla
            python torchchat.py generate --checkpoint-path "${MODEL_PATH}" --attention-backend "${SDPA}" --device "${DEVICE}" --dtype "${DTYPE}" --temperature 0

            ###################################################################
            # prefill, and compile and prefill compile
            python torchchat.py generate --checkpoint-path "${MODEL_PATH}" --attention-backend "${SDPA}" --device "${DEVICE}" --dtype "${DTYPE}" --temperature 0 --compile --compile-prefill

            ###################################################################
            # sequential prefill
            python torchchat.py generate --checkpoint-path "${MODEL_PATH}" --attention-backend "${SDPA}" --device "${DEVICE}" --dtype "${DTYPE}" --temperature 0 --sequential-prefill

            ###################################################################
            # prefill, and compile
            python torchchat.py generate --checkpoint-path "${MODEL_PATH}" --attention-backend "${SDPA}" --device "${DEVICE}" --dtype "${DTYPE}" --temperature 0 --sequential-prefill --compile
          done
        done
      done

      echo "tests complete"
      echo "******************************************"
      echo "::endgroup::"

0 commit comments

Comments (0)