Skip to content

Commit 06458a0

Browse files
Upgrade to CUDA 12 (#1527)
Co-authored-by: Woosuk Kwon <[email protected]>
1 parent 1a2bbc9 commit 06458a0

File tree

5 files changed

+15 -7 lines changed

.github/workflows/publish.yml

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -43,14 +43,14 @@ jobs:
43 43
name: Build Wheel
44 44
runs-on: ${{ matrix.os }}
45 45
needs: release
46-
46+
47 47
strategy:
48 48
fail-fast: false
49 49
matrix:
50 50
os: ['ubuntu-20.04']
51 51
python-version: ['3.8', '3.9', '3.10', '3.11']
52-
pytorch-version: ['2.0.1']
53-
cuda-version: ['11.8'] # Github runner can't build anything older than 11.8
52+
pytorch-version: ['2.1.0']
53+
cuda-version: ['12.1']
54 54

55 55
steps:
56 56
- name: Checkout
@@ -82,7 +82,7 @@ jobs:
82 82
asset_name=${wheel_name//"linux"/"manylinux1"}
83 83
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
84 84
echo "asset_name=${asset_name}" >> $GITHUB_ENV
85-
85+
86 86
- name: Upload Release Asset
87 87
uses: actions/upload-release-asset@v1
88 88
env:

.github/workflows/scripts/build.sh

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,5 +11,8 @@ LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
11 11
$python_executable -m pip install wheel packaging
12 12
$python_executable -m pip install -r requirements.txt
13 13

14+
# Limit the number of parallel jobs to avoid OOM
15+
export MAX_JOBS=1
16+
14 17
# Build
15 18
$python_executable setup.py bdist_wheel --dist-dir=dist

.github/workflows/scripts/cuda-install.sh

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,3 +16,8 @@ sudo apt clean
16 16
# Test nvcc
17 17
PATH=/usr/local/cuda-$1/bin:${PATH}
18 18
nvcc --version
19+
20+
# Log gcc, g++, c++ versions
21+
gcc --version
22+
g++ --version
23+
c++ --version

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@ requires = [
3 3
"ninja",
4 4
"packaging",
5 5
"setuptools",
6-
"torch == 2.0.1",
6+
"torch >= 2.1.0",
7 7
"wheel",
8 8
]
9 9
build-backend = "setuptools.build_meta"

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,9 @@ pandas # Required for Ray data.
5 5
pyarrow # Required for Ray data.
6 6
sentencepiece # Required for LLaMA tokenizer.
7 7
numpy
8-
torch == 2.0.1
8+
torch >= 2.1.0
9 9
transformers >= 4.34.0 # Required for Mistral.
10-
xformers == 0.0.22 # Required for Mistral.
10+
xformers >= 0.0.22.post7 # Required for CUDA 12.1.
11 11
fastapi
12 12
uvicorn[standard]
13 13
pydantic == 1.10.13 # Required for OpenAI server.

0 commit comments

Comments
 (0)