Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 30cd5b7

Browse files
authored
Update more-tests.yml
Add tests for sdpa backends with server export (x86 cpu & cuda)
1 parent 1ded204 commit 30cd5b7

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

.github/workflows/more-tests.yml

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,61 @@ jobs:
8383
echo "tests complete"
8484
echo "******************************************"
8585
echo "::endgroup::"
86+
87+
88+
# Exercises every SDPA attention backend (math / flash / efficient / cudnn)
# across cpu+cuda devices and three dtypes, via both the DSO export path
# (run with Python) and the AOTI export path (run with the aoti_run binary).
test-sdpa-backends-export:
  permissions:
    id-token: write
    contents: read
  uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
  with:
    runner: linux.g5.4xlarge.nvidia.gpu
    gpu-arch-type: cuda
    gpu-arch-version: "12.4"
    timeout: 60
    script: |
      echo "::group::Print machine info"
      uname -a
      echo "::endgroup::"

      # FIX: this group was mislabeled "Download checkpoints" (duplicating the
      # group below); it installs dependencies, so name it accordingly.
      echo "::group::Install requirements"
      ./install/install_requirements.sh cuda
      pip3 list
      python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
      echo "::endgroup::"

      echo "::group::Download checkpoints"
      mkdir -p checkpoints/stories15M
      pushd checkpoints/stories15M
      wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
      wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
      popd
      echo "::endgroup::"

      echo "::group::Run inference"
      export MODEL_PATH=checkpoints/stories15M/stories15M.pt
      export MODEL_NAME=stories15M
      # FIX: tokenizer.model is downloaded into checkpoints/stories15M (see
      # group above), not /tmp. The aoti_run invocation below reads
      # ${MODEL_DIR}/tokenizer.model, so MODEL_DIR must point at the
      # directory that actually holds the tokenizer.
      export MODEL_DIR=checkpoints/stories15M

      for DEVICE in cpu cuda; do
        # depending on how the parameter passing works, may only be able to do bfloat16 for aoti_run, similar to runner-cuda-dtype.yml
        # (although the runner environment should not have an opinion what we use in the artifact, and we might suitably abstract that)
        for DTYPE in bfloat16 float16 float32; do
          # NOTE(review): cudnn_attention is CUDA-only in upstream torch —
          # confirm the cpu iteration either falls back or is expected to pass.
          for SDPA in 'math' 'flash_attention' 'efficient_attention' 'cudnn_attention'; do
            ###################################################################
            # Export DSO and run with Python
            python torchchat.py export --output-dso dso.so --checkpoint-path "${MODEL_PATH}" --attention-backend "${SDPA}" --device "${DEVICE}" --dtype "${DTYPE}" --temperature 0
            python torchchat.py generate --dso-path dso.so --checkpoint-path "${MODEL_PATH}" --attention-backend "${SDPA}" --device "${DEVICE}" --dtype "${DTYPE}" --temperature 0 --prompt "Once upon a time"
            ###################################################################
            # Export AOTI (.pt2) and run with the aoti_run C++ runner
            # NOTE(review): ./cmake-out/aoti_run is assumed to be built by
            # install_requirements.sh — verify, no explicit build step here.
            python torchchat.py export --output-aoti /tmp/model.pt2 --checkpoint-path "${MODEL_PATH}" --attention-backend "${SDPA}" --device "${DEVICE}" --dtype "${DTYPE}" --temperature 0
            ./cmake-out/aoti_run /tmp/model.pt2 -z "${MODEL_DIR}"/tokenizer.model -i "Once upon a time"
            ###################################################################
          done
        done
      done

      echo "tests complete"
      echo "******************************************"
      echo "::endgroup::"

0 commit comments

Comments
 (0)