Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
69 changes: 67 additions & 2 deletions .github/workflows/more-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
gpu-arch-version: "12.4"
timeout: 60
script: |
set -xeou pipefail
echo "::group::Print machine info"
uname -a
echo "::endgroup::"
Expand All @@ -39,9 +40,10 @@ jobs:
echo "::endgroup::"
echo "::group::Run inference"
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
export MODEL_DIR=checkpoints/stories15M/
export MODEL_PATH=${MODEL_DIR}/stories15M.pt
export MODEL_NAME=stories15M
export MODEL_DIR=/tmp
for DTYPE in bfloat16 float16 float32; do
###################################################################
Expand Down Expand Up @@ -83,3 +85,66 @@ jobs:
echo "tests complete"
echo "******************************************"
echo "::endgroup::"
test-sdpa-backends-export:
permissions:
id-token: write
contents: read
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
runner: linux.g5.4xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: "12.4"
timeout: 60
script: |
set -xeou pipefail
echo "::group::Print machine info"
uname -a
echo "::endgroup::"
echo "::group::Download checkpoints"
# Install requirements
./install/install_requirements.sh cuda
pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
echo "::endgroup::"
echo "::group::Download checkpoints"
mkdir -p checkpoints/stories15M
pushd checkpoints/stories15M
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
popd
echo "::endgroup::"
echo "::group::Run inference"
export MODEL_DIR=checkpoints/stories15M/
export MODEL_PATH=${MODEL_DIR}/stories15M.pt
export MODEL_NAME=stories15M
./torchchat/utils/scripts/build_native.sh aoti
for DEVICE in cpu cuda; do
# depending on how the parameter passing works, may only be able to do bfloat16 for aoti_run, similar to runner-cuda-dtype.yml
# (although the runner environment should not have an opinion what we us in the artifact, and we might suitably abstract that)
for DTYPE in bfloat16 float16 float32; do
for SDPA in 'math' 'flash_attention' 'efficient_attention' 'cudnn_attention'; do
echo "***************************************************************"
echo "*** $DEVICE $DTYPE $SDPA"
###################################################################
# Export DSO and run with Python
python torchchat.py export --output-dso dso.so --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE}
python torchchat.py generate --dso-path dso.so --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --prompt "Once upon a time"
###################################################################
# Export AOTI and run with aoti_run
python torchchat.py export --output-aoti /tmp/model.pt2 --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE}
./cmake-out/aoti_run /tmp/model.pt2 -z ${MODEL_DIR}/tokenizer.model -i "Once upon a time"
###################################################################
done
done
done
echo "tests complete"
echo "******************************************"
echo "::endgroup::"
33 changes: 17 additions & 16 deletions torchchat/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,13 +490,14 @@ def main(args):
print(
"WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
)
export_for_server(
model_to_dso,
builder_args.device,
output_dso_path,
builder_args.dynamic_shapes,
package=False,
)
with torch.nn.attention.sdpa_kernel([builder_args.attention_backend]):
export_for_server(
model_to_dso,
builder_args.device,
output_dso_path,
builder_args.dynamic_shapes,
package=False,
)

if output_aoti_package_path:
output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
Expand All @@ -512,14 +513,15 @@ def main(args):
print(
"Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
)
export_for_server(
model_to_aoti_package,
builder_args.device,
output_aoti_package_path,
builder_args.dynamic_shapes,
package=True,
metadata=metadata,
)
with torch.nn.attention.sdpa_kernel([builder_args.attention_backend]):
export_for_server(
model_to_aoti_package,
builder_args.device,
output_aoti_package_path,
builder_args.dynamic_shapes,
package=True,
metadata=metadata,
)

if output_snapshot_path:
output_snapshot_path = str(os.path.abspath(output_snapshot_path))
Expand All @@ -529,4 +531,3 @@ def main(args):
builder_args.device,
output_snapshot_path,
)

Loading