Commit c96965d

ryan-williams and claude committed
Add GitHub Actions workflows for GPU testing on EC2
- test.yaml: Reusable workflow that provisions EC2 GPU instances and runs pytest
  - Supports g5 (A10G) and g6 (L4) instance types
  - Uses Deep Learning AMI with pre-installed PyTorch
  - Configures TORCH_CUDA_ARCH_LIST for fast targeted builds
  - Runs tests with --maxfail=10 to gather more failure data
- tests.yaml: Main workflow that runs tests on multiple GPU types
  - Tests on both g5.2xlarge (A10G) and g6.2xlarge (L4) in parallel
  - Triggered on push/PR to main or manual dispatch

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 2beeed6 commit c96965d

File tree

2 files changed: +99 -0 lines changed


.github/workflows/test.yaml

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
name: GPU tests
on:
  workflow_dispatch:
    inputs:
      instance_type:
        description: 'EC2 instance type'
        required: false
        type: choice
        default: 'g6.2xlarge'
        options:
          - g5.xlarge   # 4 vCPUs, 16GB RAM, A10G GPU, ≈$1.11/hr
          - g5.2xlarge  # 8 vCPUs, 32GB RAM, A10G GPU, ≈$1.33/hr
          - g5.4xlarge  # 16 vCPUs, 64GB RAM, A10G GPU, ≈$1.79/hr
          - g6.xlarge   # 4 vCPUs, 16GB RAM, L4 GPU, ≈$0.89/hr
          - g6.2xlarge  # 8 vCPUs, 32GB RAM, L4 GPU, ≈$1.08/hr
          - g6.4xlarge  # 16 vCPUs, 64GB RAM, L4 GPU, ≈$1.46/hr
  workflow_call:
    inputs:
      instance_type:
        description: 'EC2 instance type'
        required: true
        type: string
permissions:
  id-token: write
  contents: read
jobs:
  ec2:
    name: Start EC2 runner
    uses: Open-Athena/ec2-gha/.github/workflows/runner.yml@v2
    with:
      ec2_instance_type: ${{ inputs.instance_type || 'g6.2xlarge' }}
      ec2_image_id: ami-0aee7b90d684e107d  # Deep Learning OSS Nvidia Driver AMI GPU PyTorch 2.4.1 (Ubuntu 22.04) 20250623
    secrets:
      GH_SA_TOKEN: ${{ secrets.GH_SA_TOKEN }}
  test:
    name: GPU tests
    needs: ec2
    runs-on: ${{ needs.ec2.outputs.id }}
    steps:
      - uses: actions/checkout@v4
      - name: Setup Python environment
        run: |
          # Use the DLAMI's pre-installed PyTorch conda environment
          echo "/opt/conda/envs/pytorch/bin" >> $GITHUB_PATH
          echo "CONDA_DEFAULT_ENV=pytorch" >> $GITHUB_ENV
      - name: Check GPU
        run: nvidia-smi
      - name: Install mamba-ssm and test dependencies
        run: |
          # Use all available CPUs for compilation (we're only building for one GPU arch)
          export MAX_JOBS=$(nproc)

          INSTANCE_TYPE="${{ inputs.instance_type || 'g6.2xlarge' }}"

          # Set CUDA architecture based on GPU type
          # TORCH_CUDA_ARCH_LIST tells PyTorch which specific architecture to compile for
          if [[ "$INSTANCE_TYPE" == g5.* ]]; then
            export TORCH_CUDA_ARCH_LIST="8.6"  # A10G GPU (Ampere)
            export CUDA_VISIBLE_DEVICES=0
            export NVCC_GENCODE="-gencode arch=compute_86,code=sm_86"
          elif [[ "$INSTANCE_TYPE" == g6.* ]]; then
            export TORCH_CUDA_ARCH_LIST="8.9"  # L4 GPU (Ada Lovelace)
            export CUDA_VISIBLE_DEVICES=0
            export NVCC_GENCODE="-gencode arch=compute_89,code=sm_89"
          fi

          echo "Building with MAX_JOBS=$MAX_JOBS for $INSTANCE_TYPE"

          # Install mamba-ssm with causal-conv1d and dev dependencies
          # Note: causal-conv1d will download pre-built wheels when available
          pip install -v --no-build-isolation -e .[causal-conv1d,dev]
      - name: Run tests
        run: pytest -vs --maxfail=10 tests/
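
The arch-selection branch in the install step boils down to a mapping from EC2 instance family to CUDA compute capability. A minimal sketch of that mapping as a standalone shell function (the function name `arch_for_instance` is hypothetical, not part of the workflow):

arch_for_instance() {
  # Map an EC2 instance type to the TORCH_CUDA_ARCH_LIST value the workflow exports
  case "$1" in
    g5.*) echo "8.6" ;;  # A10G GPU (Ampere)
    g6.*) echo "8.9" ;;  # L4 GPU (Ada Lovelace)
    *)    echo ""    ;;  # unknown family: leave unset so PyTorch auto-detects
  esac
}

arch_for_instance g5.2xlarge  # prints 8.6

Pinning a single architecture this way avoids compiling fat binaries for every supported compute capability, which is what makes the build fast enough for CI.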

.github/workflows/tests.yaml

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
name: GPU tests on multiple instance types
on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
  workflow_dispatch:

permissions:
  id-token: write
  contents: read

jobs:
  test-g5:
    name: Test on g5.2xlarge (A10G)
    uses: ./.github/workflows/test.yaml
    with:
      instance_type: g5.2xlarge
    secrets: inherit

  test-g6:
    name: Test on g6.2xlarge (L4)
    uses: ./.github/workflows/test.yaml
    with:
      instance_type: g6.2xlarge
    secrets: inherit
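
Besides the push/PR triggers, both workflows expose workflow_dispatch, so they can be started by hand. A hedged sketch of manual dispatch via the GitHub CLI (assumes an authenticated `gh` session in a clone of this repo; not part of the committed files):

# Run the full multi-instance suite
gh workflow run tests.yaml

# Or run the reusable workflow directly on one instance type
gh workflow run test.yaml -f instance_type=g5.2xlarge

The `-f` flag supplies the `instance_type` input declared under `workflow_dispatch.inputs` in test.yaml; omitting it falls back to the `g6.2xlarge` default.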
